Quantized ops in QuantizationSimModel causing extreme slowdown #3976

Eric-Lee opened this issue Apr 11, 2025 · 2 comments

Hi,

AIMET version 1.35.0
Torch 2.1.2+cu121

I am evaluating AIMET to quantize Llama3.2 3B. I use prepare_model (pro) to prepare the model for quantization, create a QuantizationSimModel (v2), and then compute encodings with the post_training_tf quant scheme and channelwise encodings. After computing encodings, I use the sim model in a token generation loop to check the output quality of the quantization.

The problem is that the forward pass is extremely slow compared to the model without quantization ops. For instance, on a single A100 in FP32, the forward pass of the quant sim model is about 20x slower than the forward pass of QuantizationSimModel.get_original_model(model) (which also seems slow, but that's a different issue). A 20x slowdown from the quantization ops seems unusual. In both cases, GPU utilization is at 100%.

So my question is: is this kind of overhead expected? If not, are there any guidelines on how to configure the QuantizationSimModel for performance? I am aware that the main purpose of the sim is quantization simulation and QAT, not model inference, but I would still like to use it in my workflow to check the output quality.
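
Roughly, my setup looks like the sketch below. This is a minimal sketch with placeholder names: the model loading helper, calibration prompts, and per-channel config handling are stand-ins, and the compute_encodings callback convention may differ slightly across AIMET versions.

import torch
from aimet_common.defs import QuantScheme
from aimet_torch.v2.quantsim import QuantizationSimModel

model = load_llama_3b()                          # placeholder: prepared Llama3.2 3B model
dummy_input = torch.randint(0, 32000, (1, 128))  # token ids; shape is illustrative

sim = QuantizationSimModel(
    model,
    dummy_input=dummy_input,
    quant_scheme=QuantScheme.post_training_tf,
    # channelwise (per-channel) quantization is enabled via the quantsim
    # config file in my setup; omitted here for brevity
)

def calibrate(sim_model, prompts):
    # Run a few representative token batches through the model to collect statistics.
    with torch.no_grad():
        for batch in prompts:
            sim_model(batch)

sim.compute_encodings(calibrate, calibration_prompts)  # calibration_prompts: assumed iterable of token-id batches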

@quic-kyunggeu quic-kyunggeu self-assigned this Apr 28, 2025

quic-kyunggeu commented Apr 30, 2025

Hi @Eric-Lee, thanks for filing this issue.
This has been a very hot topic recently, and we are well aware of it. (In our observation, a 20x slowdown is not uncommon in token generation tasks.)

Overall, LLM token generation is the least favorable workload for AIMET -- it is heavily memory-bound, and AIMET adds quantize-dequantize operations everywhere, which are also memory-heavy.

In aimet 2.5, planned for release on 5/5, we made some major optimizations for LLM token generation.

1. Fold param quantizers

About 50% of the slowdown comes from on-the-fly weight quantization:

(This QDQ accounts for 50% of the slowdown in token generation)
       |
       V
W --> QDQ ---+
x --> QDQ ---+---- Linear --- QDQ ----> y

So we recently introduced a new API, sim.fold_param_quantizers, which does the following:

    W_qdq ---+
x --> QDQ ---+---- Linear --- QDQ ----> y

This way, you can speed up your token generation by about 2x -- still way slower than fp16 models, but a big improvement nonetheless.

Expected usage

>>> sim = QuantizationSimModel(...)
>>> type(sim.model[0].weight)
<class 'torch.nn.parameter.Parameter'>
>>> sim.model[0]
QuantizedLinear(
  in_features=10, out_features=10, bias=True
  (param_quantizers): ModuleDict(
    (weight): QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
    (bias): None
  )
)
>>> sim.fold_param_quantizers()
>>> type(sim.model[0].weight)
<class 'aimet_torch.v2.quantization.tensor.DequantizedTensor'>
>>> sim.model[0]
QuantizedLinear(
  in_features=10, out_features=10, bias=True
  (param_quantizers): ModuleDict(
    (weight): None
    (bias): None
  )
)

For more information, see ab7b19d.
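
As a rough illustration, this is how fold_param_quantizers would slot into a token-generation loop like yours. The greedy decoding loop and the .logits access below are illustrative assumptions about your setup, not AIMET APIs.

sim.fold_param_quantizers()   # bake the weight QDQ into the parameters once, up front
sim.model.eval()

tokens = prompt_token_ids     # assumed: LongTensor of shape (1, seq_len)
with torch.no_grad():
    for _ in range(max_new_tokens):
        out = sim.model(tokens)
        logits = out.logits                   # assumes an HF-style causal-LM output
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)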

2. Using fused quantize-dequantize

In aimet 2.5, we sped up the quantize-dequantize kernel itself using PyTorch built-in kernels.
You don't need to make any code changes to get this speedup; it is enabled by default.
With this optimization, you get an additional 10-30% speedup on top of param quantizer folding.

For more information, see a444995.

quic-kyunggeu (Contributor) commented:

With these two optimizations, you should get a 2-2.5x speedup in token generation tasks.
With them applied, if your LLM token generation is about 3-6x slower than fp16 in aimet >= 2.5, you're in the normal range. If not, please let us know.
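
If you want to check where you land, a crude way to measure the ratio is to time a forward pass with and without quantsim. This is just a sketch; sim, original_model, and input_ids are assumed to be defined in your script.

import time
import torch

def time_forward(model, input_ids, n_iters=20):
    # Crude per-iteration latency of a single forward pass; synchronize so
    # GPU kernels are included in the measured wall-clock time.
    model.eval()
    with torch.no_grad():
        for _ in range(3):                    # warm-up
            model(input_ids)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            model(input_ids)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / n_iters

t_sim = time_forward(sim.model, input_ids)        # quantsim forward
t_base = time_forward(original_model, input_ids)  # fp16/fp32 baseline
print(f"quantsim / baseline slowdown: {t_sim / t_base:.1f}x")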
