Quantized ops in QuantizationSimModel causing extreme slowdown #3976
Comments
Hi @Eric-Lee, thanks for filing this issue. Overall, LLM token generation is the least favorable workload for AIMET -- it is heavily memory-bound, and AIMET adds quantize-dequantize operations everywhere, which are also memory-heavy. In aimet 2.5 (planned for 5/5), we made some major optimizations for LLM token generation.

1. Fold param quantizers

About 50% of the slowdown comes from on-the-fly weight quantization.
So we recently introduced a new API called fold_param_quantizers, which quantize-dequantizes the weights once up front and removes the parameter quantizers, so the weights are no longer quantized on every forward pass.
This way, you can speed up your token generation by roughly 2x -- still much slower than fp16 models, but a big improvement nonetheless.

Expected usage:

>>> sim = QuantizationSimModel(...)
>>> type(sim.model[0].weight)
<class 'torch.nn.parameter.Parameter'>
>>> sim.model[0]
QuantizedLinear(
in_features=10, out_features=10, bias=True
(param_quantizers): ModuleDict(
(weight): QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
(bias): None
)
)
>>> sim.fold_param_quantizers()
>>> type(sim.model[0].weight)
<class 'aimet_torch.v2.quantization.tensor.DequantizedTensor'>
>>> sim.model[0]
QuantizedLinear(
in_features=10, out_features=10, bias=True
(param_quantizers): ModuleDict(
(weight): None
(bias): None
)
)

For more information, see ab7b19d.

2. Using fused quantize-dequantize

In aimet 2.5, we also sped up the quantize-dequantize kernel itself by using PyTorch built-in kernels. For more information, see a444995.
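As a rough illustration of what fusing the quantize-dequantize with a PyTorch built-in kernel buys you (a sketch only, not the actual AIMET kernel -- the tensor size and quantization parameters below are made up):

import torch

# Hypothetical tensor and quantization parameters, chosen only for illustration.
x = torch.randn(1024, 1024)
scale, zero_point, qmin, qmax = 0.02, 0, -128, 127

def naive_quant_dequant(t: torch.Tensor) -> torch.Tensor:
    # Several separate elementwise kernels (divide, round, add, clamp, subtract,
    # multiply) -- each one reads and writes the whole tensor, so the round trip
    # is memory-bound.
    q = torch.clamp(torch.round(t / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

def fused_quant_dequant(t: torch.Tensor) -> torch.Tensor:
    # A single built-in PyTorch kernel that performs the same round trip in one pass.
    return torch.fake_quantize_per_tensor_affine(t, scale, zero_point, qmin, qmax)

assert torch.allclose(naive_quant_dequant(x), fused_quant_dequant(x))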
With these two techniques, you'll get a 2-2.5x speedup in token generation tasks.
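If you want to measure the effect on your own setup, here is a minimal timing sketch (it assumes the sim from above and an input_ids batch taken from your generation loop; both are placeholders):

import time
import torch

def time_forward(model, inputs, iters=20):
    # Warm-up iterations so lazy initialization doesn't skew the measurement.
    with torch.no_grad():
        for _ in range(3):
            model(inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

baseline = time_forward(sim.model, input_ids)
sim.fold_param_quantizers()
folded = time_forward(sim.model, input_ids)
print(f"per-forward latency: {baseline * 1e3:.1f} ms -> {folded * 1e3:.1f} ms")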
Hi,
AIMET version 1.35.0
Torch 2.1.2+cu121
I am evaluating AIMET for quantizing Llama 3.2 3B. I use prepare_model (pro) to prepare the model for quantization, create a QuantizationSimModel (v2), and then compute encodings with the post_training_tf quant scheme and channel-wise encodings. After computing encodings, I run the sim model in a token generation loop to check the output quality of the quantization.

The problem is that the forward pass is extremely slow compared to the model without quantization ops. For instance, on a single A100 in FP32, the forward pass of the quant sim model is about 20x slower than the forward pass of QuantizationSimModel.get_original_model(model) (which also seems slow, but that's a different issue). A 20x slowdown from the quantization ops seems unusual. In both cases, GPU utilization is at 100%.

So my question is: is this kind of overhead expected? If not, are there any guidelines on how to properly configure the QuantizationSimModel for performance? I am aware that the main purpose of the sim is quantization and QAT rather than inference, but I would still like to use it in my workflow to check output quality.