From 893f02692d5d5f0fdd9134501e2ba08536a39458 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Tue, 22 Apr 2025 17:15:19 -0700
Subject: [PATCH] init

---
 torchao/quantization/utils.py | 277 +++++++++++++++++++++++++++++++++-
 1 file changed, 276 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index f5bdfa9193..545664bea6 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import importlib.util
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -12,6 +12,13 @@
 from torchao.kernel import (
     int_scaled_matmul,
 )
+from torchao.quantization.granularity import (
+    Granularity,
+    PerAxis,
+    PerGroup,
+    PerTensor,
+    PerToken,
+)
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -141,6 +148,24 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
     return block_size
 
 
+def _get_block_size_from_granularity(
+    x: torch.Tensor, granularity: Granularity
+) -> Tuple[int, ...]:
+    if isinstance(granularity, PerTensor):
+        block_size = x.shape
+    elif isinstance(granularity, PerToken):
+        block_size = _get_per_token_block_size(x)
+    elif isinstance(granularity, PerGroup):
+        block_size = list(x.shape)
+        block_size[-1] = granularity.group_size
+    elif isinstance(granularity, PerAxis):
+        block_size = list(x.shape)
+        block_size[granularity.axis] = 1
+    else:
+        raise ValueError(f"Unsupported granularity: {granularity}")
+    return tuple(block_size)
+
+
 # taken from
 # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
 # and slightly modified
@@ -631,6 +656,256 @@ def per_token_dynamic_quant(
     return dq
 
 
+from torchao.utils import _register_custom_op
+
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
+
+_GRANULARITY_STR_AND_KWARG_KEYS = [
+    (PerTensor, "PER_TENSOR", []),
+    (PerAxis, "PER_AXIS", ["axis"]),
+    (PerGroup, "PER_GROUP", ["group_size"]),
+    (PerToken, "PER_TOKEN", []),
+]
+
+
+def _get_str_and_kwargs_for_granularity(granularity: Granularity):
+    for cls, cls_str, cls_kwarg_keys in _GRANULARITY_STR_AND_KWARG_KEYS:
+        if isinstance(granularity, cls):
+            return cls_str, {k: getattr(granularity, k) for k in cls_kwarg_keys}
+    raise ValueError(f"Unsupported granularity: {granularity}")
+
+
+def _get_granularity_from_str_and_kwargs(granularity_str, kwargs):
+    for cls, cls_str, cls_kwarg_keys in _GRANULARITY_STR_AND_KWARG_KEYS:
+        if cls_str == granularity_str:
+            return cls(**{k: kwargs[k] for k in cls_kwarg_keys})
+    raise ValueError(f"Unsupported granularity: {granularity_str}")
+
+
+def qdq_dynamic_affine_quantization(
+    input: torch.Tensor,
+    mapping_type: MappingType,
+    granularity: Granularity,
+    quant_min: int,
+    quant_max: int,
+    zero_point_domain: ZeroPointDomain,
+    *,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    eps: Optional[float] = None,
+) -> torch.Tensor:
+    granularity_str, kwargs = _get_str_and_kwargs_for_granularity(granularity)
+    return _qdq_dynamic_affine_quantization(
+        input,
+        mapping_type.name,
+        granularity_str,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        scale_dtype=scale_dtype,
+        zero_point_dtype=zero_point_dtype,
+        eps=eps,
+        **kwargs,
+    )
+
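For reference, a minimal usage sketch of qdq_dynamic_affine_quantization; the tensor shape, group size, and int8 range below are illustrative assumptions rather than values from the patch. The wrapper converts the Granularity instance into a string plus axis/group_size kwargs, since torch.library op schemas cannot accept arbitrary Python objects:

import torch

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.quantization.utils import qdq_dynamic_affine_quantization

# Example activation; shape and group size chosen for illustration.
x = torch.randn(4, 128)
x_qdq = qdq_dynamic_affine_quantization(
    x,
    MappingType.SYMMETRIC,
    PerGroup(32),  # one scale per group of 32 trailing-dim elements
    quant_min=-128,  # must fit the op's hardcoded torch.int8
    quant_max=127,
    zero_point_domain=ZeroPointDomain.NONE,  # symmetric: no zero point
)
# Same shape/dtype as the input, with int8 quantization error applied.
assert x_qdq.shape == x.shape and x_qdq.dtype == x.dtype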
+@register_custom_op
+def _qdq_dynamic_affine_quantization(
+    input: torch.Tensor,
+    mapping_type: str,
+    granularity: str,
+    quant_min: int,
+    quant_max: int,
+    zero_point_domain: str,
+    *,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    eps: Optional[float] = None,
+    # granularity kwargs
+    axis: Optional[int] = None,
+    group_size: Optional[int] = None,
+) -> torch.Tensor:
+    granularity = _get_granularity_from_str_and_kwargs(
+        granularity, {"axis": axis, "group_size": group_size}
+    )
+    block_size = _get_block_size_from_granularity(input, granularity)
+    quant_dtype = torch.int8
+    assert (
+        quant_max <= torch.iinfo(quant_dtype).max
+    ), f"quant_max {quant_max} is too large for {quant_dtype}"
+    assert (
+        quant_min >= torch.iinfo(quant_dtype).min
+    ), f"quant_min {quant_min} is too small for {quant_dtype}"
+    output_dtype = input.dtype
+
+    mapping_type = getattr(MappingType, mapping_type.upper())
+    zero_point_domain = getattr(ZeroPointDomain, zero_point_domain.upper())
+    preserve_zero = (mapping_type == MappingType.SYMMETRIC) or (
+        zero_point_domain == ZeroPointDomain.NONE
+    )
+    scale, zero_point = choose_qparams_affine(
+        input,
+        mapping_type,
+        block_size,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain,
+    )
+    q = quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+    )
+    dq = dequantize_affine(
+        q,
+        block_size,
+        scale,
+        zero_point,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+        output_dtype=output_dtype,
+    )
+    return dq
+
+
+def qdq_weight_only_affine_quantized_embedding(
+    indices: torch.Tensor,
+    w_int_data: torch.Tensor,
+    w_quant_min: int,
+    w_quant_max: int,
+    w_scale: torch.Tensor,
+    w_zero_point: Optional[torch.Tensor],
+    granularity: Granularity,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    granularity_str, kwargs = _get_str_and_kwargs_for_granularity(granularity)
+    return _qdq_weight_only_affine_quantized_embedding(
+        indices,
+        w_int_data,
+        w_quant_min,
+        w_quant_max,
+        w_scale,
+        w_zero_point,
+        granularity_str,
+        output_dtype,
+        **kwargs,
+    )
+
+
+@register_custom_op
+def _qdq_weight_only_affine_quantized_embedding(
+    indices: torch.Tensor,
+    w_int_data: torch.Tensor,
+    w_quant_min: int,
+    w_quant_max: int,
+    w_scale: torch.Tensor,
+    w_zero_point: Optional[torch.Tensor],
+    granularity: str,
+    output_dtype: torch.dtype,
+    *,
+    # granularity kwargs
+    axis: Optional[int] = None,
+    group_size: Optional[int] = None,
+) -> torch.Tensor:
+    granularity = _get_granularity_from_str_and_kwargs(
+        granularity, {"axis": axis, "group_size": group_size}
+    )
+    block_size = _get_block_size_from_granularity(w_int_data, granularity)
+    input_dtype = w_int_data.dtype
+    w_zero_point_domain = ZeroPointDomain.NONE
+    if w_zero_point is not None:
+        assert w_zero_point.dtype in [torch.int8, torch.int32, torch.int64]
+        w_zero_point_domain = ZeroPointDomain.INT
+    w_dq = dequantize_affine(
+        w_int_data,
+        block_size,
+        w_scale,
+        w_zero_point,
+        input_dtype,
+        w_quant_min,
+        w_quant_max,
+        w_zero_point_domain,
+        output_dtype=output_dtype,
+    )
+    return torch.nn.functional.embedding(indices, w_dq)
+
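A sketch of driving the embedding op end to end, assuming the weight was quantized up front with choose_qparams_affine/quantize_affine from quant_primitives; the vocabulary size, embedding dim, and per-row granularity are example choices. The linear op below follows the same calling pattern, with a matmul over activations instead of an index lookup:

import torch

from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)
from torchao.quantization.utils import qdq_weight_only_affine_quantized_embedding

# Example embedding table, quantized per row (axis 0) to int8.
weight = torch.randn(1000, 64)
block_size = (1, weight.shape[1])  # one scale per row, i.e. PerAxis(axis=0)
scale, zero_point = choose_qparams_affine(
    weight, MappingType.SYMMETRIC, block_size, torch.int8, -128, 127
)
w_int_data = quantize_affine(
    weight, block_size, scale, zero_point, torch.int8, -128, 127
)

indices = torch.tensor([1, 5, 42])
out = qdq_weight_only_affine_quantized_embedding(
    indices,
    w_int_data,
    -128,
    127,
    scale,
    None,  # symmetric weights: takes the ZeroPointDomain.NONE path
    PerAxis(axis=0),
)
assert out.shape == (3, 64)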
+
+def qdq_weight_only_affine_quantized_linear(
+    x: torch.Tensor,
+    w_int_data: torch.Tensor,
+    w_quant_min: int,
+    w_quant_max: int,
+    w_scale: torch.Tensor,
+    w_zero_point: Optional[torch.Tensor],
+    granularity: Granularity,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    granularity_str, kwargs = _get_str_and_kwargs_for_granularity(granularity)
+    return _qdq_weight_only_affine_quantized_linear(
+        x,
+        w_int_data,
+        w_quant_min,
+        w_quant_max,
+        w_scale,
+        w_zero_point,
+        granularity_str,
+        bias,
+        **kwargs,
+    )
+
+
+@register_custom_op
+def _qdq_weight_only_affine_quantized_linear(
+    x: torch.Tensor,
+    w_int_data: torch.Tensor,
+    w_quant_min: int,
+    w_quant_max: int,
+    w_scale: torch.Tensor,
+    w_zero_point: Optional[torch.Tensor],
+    granularity: str,
+    bias: Optional[torch.Tensor],
+    *,
+    # granularity kwargs
+    axis: Optional[int] = None,
+    group_size: Optional[int] = None,
+) -> torch.Tensor:
+    granularity = _get_granularity_from_str_and_kwargs(
+        granularity, {"axis": axis, "group_size": group_size}
+    )
+    block_size = _get_block_size_from_granularity(w_int_data, granularity)
+    output_dtype = x.dtype
+    input_dtype = w_int_data.dtype
+    w_zero_point_domain = ZeroPointDomain.NONE
+    if w_zero_point is not None:
+        assert w_zero_point.dtype in [torch.int8, torch.int32, torch.int64]
+        w_zero_point_domain = ZeroPointDomain.INT
+    w_dq = dequantize_affine(
+        w_int_data,
+        block_size,
+        w_scale,
+        w_zero_point,
+        input_dtype,
+        w_quant_min,
+        w_quant_max,
+        w_zero_point_domain,
+        output_dtype=output_dtype,
+    )
+    return torch.nn.functional.linear(x, w_dq, bias)
+
+
 def recommended_inductor_config_setter():
     """
     Set inductor config to use the following optimizations which have been showed to improve performance for quantized models: