diff --git a/README.md b/README.md index 8c06112451..f490ff0249 100644 --- a/README.md +++ b/README.md @@ -188,11 +188,15 @@ We also support deployment to edge devices through ExecuTorch, for more detail, Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/): ```python -from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig +import torch +from torchao.quantization import quantize_, Int8DynamicActivationIntxWeightConfig, PerGroup from torchao.quantization.qat import QATConfig # prepare -base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) +base_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), +) quantize_(my_model, QATConfig(base_config, step="prepare")) # train model (not shown) @@ -266,15 +270,6 @@ The best example we have combining the composability of lower bit dtype with com Our framework makes it straightforward to add tensor parallel support to your custom quantized tensor subclass. Check out our [tensor parallel tutorial](tutorials/developer_api_guide/tensor_parallel.py) to see how a quantized tensor subclass can be extended to support column and row-wise tensor sharding while maintaining compatibility with `torch.compile`. -### Custom Kernels - -We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow - -1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))` -2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256 -3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference - -If you believe there's other CUDA kernels we should be taking a closer look at please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697) or feel free to contribute directly to the repo. 
--> ## 🔗 Integrations diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 54820cb5b3..7f276860b2 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -274,9 +274,7 @@ def test_fp8_weight_dimension_warning(self): model = ToyLinearModel(10, 25).to(_DEVICE) # 10x25 and 25x10 weights # Set up logging capture - with self.assertLogs( - "torchao.quantization.quant_api", level="INFO" - ) as log_context: + with self.assertLogs("torchao.quantization.utils", level="INFO") as log_context: quantize_( model, Float8DynamicActivationFloat8WeightConfig( diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 3f9bc4d345..afe99c4ed0 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -834,9 +834,7 @@ def test_config_deprecation(self): self.assertTrue(len(_warnings) == 1) found_deprecated = False for w in _warnings: - if "will be moving to prototype in a future release" in str( - w.message - ): + if "will be deleted in a future release" in str(w.message): found_deprecated = True self.assertTrue( found_deprecated, f"did not find deprecated warning for {cls}" diff --git a/torchao/prototype/quantization/quant_api.py b/torchao/prototype/quantization/quant_api.py new file mode 100644 index 0000000000..851a3597ac --- /dev/null +++ b/torchao/prototype/quantization/quant_api.py @@ -0,0 +1,555 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024-2025 Arm Limited and affiliates. +# All rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Quantization workflow APIs moved from `torch/quantization/quant_api.py` +to prototype. 
+""" + +import logging +import types +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +import torchao +from torchao.core.config import AOBaseConfig +from torchao.dtypes import ( + CutlassInt4PackedLayout, + Float8Layout, + Int8DynamicActInt4WeightCPULayout, + MarlinQQQLayout, + PlainLayout, + UintxLayout, + to_affine_quantized_floatx, + to_affine_quantized_intx, + to_marlinqqq_quantized_intx, +) +from torchao.dtypes.utils import Layout +from torchao.float8.config import e4m3_dtype +from torchao.float8.float8_linear import Float8Linear +from torchao.float8.inference import ( + Float8MMConfig, + FP8Granularity, + _normalize_granularity, +) +from torchao.quantization.granularity import ( + PerTensor, +) +from torchao.quantization.linear_activation_quantized_tensor import ( + to_linear_activation_quantized, +) +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_QVALUE_BOUNDS, + MappingType, + ZeroPointDomain, +) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) +from torchao.quantization.utils import ( + _fp8_mm_compat, + _linear_extra_repr, + get_block_size, +) +from torchao.quantization.weight_tensor_linear_activation_quantization import ( + to_weight_tensor_with_linear_activation_quantization_metadata, +) +from torchao.utils import ( + is_MI300, + is_sm_at_least_89, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class Int8DynamicActivationInt4WeightConfig(AOBaseConfig): + """Configuration for applying int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear + This is used to produce a model for executorch backend, but currently executorch did not + support lowering for the quantized model from this flow yet + + Args: + `group_size`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric + `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. + """ + + group_size: int = 32 + layout: Layout = PlainLayout() + mapping_type: MappingType = MappingType.SYMMETRIC + act_mapping_type: MappingType = MappingType.ASYMMETRIC + set_inductor_config: bool = True + + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Int8DynamicActivationInt4WeightConfig" + ) + warnings.warn( + "`Int8DynamicActivationInt4WeightConfig` will be deleted in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
+ ) + + +@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) +def _int8_dynamic_activation_int4_weight_transform( + module: torch.nn.Module, + config: Int8DynamicActivationInt4WeightConfig, + *, + custom_scale: Optional[torch.Tensor] = None, + custom_zero_point: Optional[torch.Tensor] = None, +): + group_size = config.group_size + layout = config.layout + mapping_type = config.mapping_type + act_mapping_type = config.act_mapping_type + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + weight = module.weight + + if group_size is None or group_size == -1: + group_size = weight.shape[-1] + if weight.shape[-1] % group_size != 0: + return module + + # weight settings + block_size = (1, group_size) + target_dtype = torch.int8 + quant_min = -8 + quant_max = 7 + + # avoid circular import + from torchao.quantization.quant_api import ( + _int4_symm_cutlass_quant, + _int8_asymm_per_token_quant, + _int8_symm_cutlass_quant, + _int8_symm_per_token_quant, + _uint8_asymm_per_token_quant, + ) + + # input settings + if act_mapping_type == MappingType.ASYMMETRIC: + if isinstance(layout, Int8DynamicActInt4WeightCPULayout): + input_quant_func = _uint8_asymm_per_token_quant + else: + input_quant_func = _int8_asymm_per_token_quant + elif act_mapping_type == MappingType.SYMMETRIC: + if isinstance(layout, MarlinQQQLayout): + input_quant_func = _int8_symm_per_token_quant + elif isinstance(layout, CutlassInt4PackedLayout): + input_quant_func = _int8_symm_cutlass_quant + else: + input_quant_func = _int8_symm_per_token_quant + else: + assert False, f"Unsupported activation mapping type: {act_mapping_type}" + + if isinstance(layout, MarlinQQQLayout): + weight = to_marlinqqq_quantized_intx( + weight, block_size, quant_min, quant_max, _layout=layout + ) + elif isinstance(layout, CutlassInt4PackedLayout): + weight = _int4_symm_cutlass_quant(weight) + elif isinstance(layout, Int8DynamicActInt4WeightCPULayout): + weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype=torch.uint8, + quant_min=0, + quant_max=15, + _layout=layout, + custom_scale=custom_scale, + custom_zero_point=custom_zero_point, + ) + else: + weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + _layout=layout, + custom_scale=custom_scale, + custom_zero_point=custom_zero_point, + ) + weight = to_linear_activation_quantized(weight, input_quant_func) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +@dataclass +class Int4DynamicActivationInt4WeightConfig(AOBaseConfig): + """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear + + Args: + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric + `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. 
+ """ + + layout: Layout = CutlassInt4PackedLayout() + mapping_type: MappingType = MappingType.SYMMETRIC + act_mapping_type: MappingType = MappingType.SYMMETRIC + set_inductor_config: bool = True + + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Int4DynamicActivationInt4WeightConfig" + ) + warnings.warn( + "`Int4DynamicActivationInt4WeightConfig` will be deleted in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." + ) + + +@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) +def _int4_dynamic_activation_int4_weight_transform( + module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig +) -> torch.nn.Module: + weight = module.weight + layout = config.layout + mapping_type = config.mapping_type + act_mapping_type = config.act_mapping_type + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if not isinstance(layout, CutlassInt4PackedLayout): + raise NotImplementedError( + f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." + ) + if mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only mapping_type=SYMMETRIC is supported.") + if act_mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") + + # avoid circular import + from torchao.quantization.quant_api import _int4_symm_cutlass_quant + + weight = _int4_symm_cutlass_quant(weight) + weight = to_linear_activation_quantized( + weight, + _int4_symm_cutlass_quant, + ) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +@dataclass +class GemliteUIntXWeightOnlyConfig(AOBaseConfig): + """ + applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format. + This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric. + + Args: + `group_size`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained + `bit_width`: bit width of the quantized weight. + `packing_bitwidth`: bit width of the packed weight, should be 8 or 32. Can have performance impacts depending on hardware. + `mode`: if set to "dynamic", activations are quantized at runtime; default is "weight_only" (weight-only quantization). + `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. + """ + + group_size: Optional[int] = 128 + bit_width: int = 4 + packing_bitwidth: Optional[int] = None + mode: Optional[str] = "weight_only" + set_inductor_config: bool = True + + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.GemliteUIntXWeightOnlyConfig" + ) + warnings.warn( + "`GemliteUIntXWeightOnlyConfig` will be deleted in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
+ ) + + +@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) +def _gemlite_uintx_weight_only_transform( + module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig +): + group_size = config.group_size + bit_width = config.bit_width + packing_bitwidth = config.packing_bitwidth + mode = config.mode + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + weight = module.weight + + from torchao.prototype.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs + + use_hqq = True if bit_width == 4 else False + new_weight = to_affine_quantized_intx( + weight, + **get_gemlite_aqt_kwargs( + weight, + group_size=group_size, + bit_width=bit_width, + packing_bitwidth=packing_bitwidth, + mode=mode, + use_hqq=use_hqq, + ), + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +@dataclass +class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): + """ + Configuration for applying float8 static symmetric quantization to + + Args: + scale (torch.Tensor): The scale tensor for activation quantization. + activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m + weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m + mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. + """ + + scale: torch.Tensor + activation_dtype: torch.dtype = e4m3_dtype + weight_dtype: torch.dtype = e4m3_dtype + granularity: Optional[ + Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]] + ] = None + mm_config: Optional[Float8MMConfig] = Float8MMConfig(use_fast_accum=True) + set_inductor_config: bool = True + + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Float8StaticActivationFloat8WeightConfig" + ) + warnings.warn( + "`Float8StaticActivationFloat8WeightConfig` will be deleted in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
+ ) + + +@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) +def _float8_static_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig +): + assert is_sm_at_least_89() or is_MI300(), ( + "Float8 static activation quantization is only supported on CUDA 8.9 and above" + ) + + if isinstance(module, Float8Linear): + # avoid circular import + from torchao.quantization.quant_api import _unwrap_float8_linear + + module = _unwrap_float8_linear(module) + + scale = config.scale + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + granularity = config.granularity + mm_config = config.mm_config + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + weight = module.weight + activation_granularity, weight_granularity = _normalize_granularity(granularity) + assert isinstance(activation_granularity, PerTensor), ( + "Static quantization only supports PerTensor granularity" + ) + + if not _fp8_mm_compat(weight): + # TODO(future PR): this should really throw an exception instead of silently + # not doing what the user asked + return module + block_size = get_block_size(weight.shape, weight_granularity) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) + + # prevent circular import + from torchao.quantization.quant_api import _input_activation_quant_func_fp8 + + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } + + quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata( + quantized_weight, + input_quant_func, + scale=scale, + zero_point=None, + quant_kwargs=input_quant_kwargs, + ) + + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +@dataclass +class UIntXWeightOnlyConfig(AOBaseConfig): + """ + Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where + x is the number of bits specified by `dtype` + + Args: + `dtype`: torch.uint1 to torch.uint7 sub byte dtypes + `group_size`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained, defaults to 64 + `pack_dim`: the dimension we use for packing, defaults to -1 + `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight + `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. + """ + + dtype: torch.dtype + group_size: int = 64 + pack_dim: int = -1 + use_hqq: bool = False + set_inductor_config: bool = True + + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig") + warnings.warn( + "`UIntXWeightOnlyConfig` will be deleted in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
+ ) + + +@register_quantize_module_handler(UIntXWeightOnlyConfig) +def _uintx_weight_only_transform( + module: torch.nn.Module, config: UIntXWeightOnlyConfig +): + dtype = config.dtype + group_size = config.group_size + pack_dim = config.pack_dim + use_hqq = config.use_hqq + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + weight = module.weight + + SUPPORTED_DTYPES = { + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, + torch.uint8, + } + assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}" + + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + + if use_hqq: + if dtype == torch.uint4: + logger.warning( + "Recommended to use `Int4WeightOnlyConfig(group_size, use_hqq=True, version=1)` for the best performance" + ) + quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] + dtype = torch.uint8 + eps = None + zero_point_dtype = None + zero_point_domain = ZeroPointDomain.FLOAT + preserve_zero = False + _layout = PlainLayout() + else: + quant_min, quant_max = None, None + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int32 + zero_point_domain = ZeroPointDomain.INT + preserve_zero = True + _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) + + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + zero_point_dtype=zero_point_dtype, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, + _layout=_layout, + use_hqq=use_hqq, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +@dataclass +class FPXWeightOnlyConfig(AOBaseConfig): + """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits + e.g. fp6_e3_m2, fp6_e2_m3, ... + The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 + github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm + For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTTensorImpl` + + This is experimental, will be merged with `to_affine_quantized_floatx` + in the future + """ + + ebits: int + mbits: int + set_inductor_config: bool = True + + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig") + warnings.warn( + "`FPXWeightOnlyConfig` will be deleted in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
+ ) + + +@register_quantize_module_handler(FPXWeightOnlyConfig) +def _fpx_weight_only_transform( + module: torch.nn.Module, config: FPXWeightOnlyConfig +) -> torch.nn.Module: + ebits = config.ebits + mbits = config.mbits + weight = module.weight + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if isinstance(module, Float8Linear): + # avoid circular import + from torchao.quantization.quant_api import _unwrap_float8_linear + + module = _unwrap_float8_linear(module) + + from torchao.dtypes import to_affine_quantized_fpx + from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout + + assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" + out_dim, in_dim = weight.shape + if (in_dim % 64 != 0) or (out_dim % 256 != 0): + logger.info( + f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because " + f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} " + "expected in_dim % 64 == 0 and out_dim % 256 == 0" + ) + return module + + _layout = FloatxTensorCoreLayout(ebits, mbits) + new_weight = to_affine_quantized_fpx(weight, _layout) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 0f386ba994..9649fe555d 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -139,44 +139,6 @@ quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. -#### A16W6 Floating Point WeightOnly Quantization - -```python -# for torch 2.4+ -from torchao.quantization import quantize_, FPXWeightOnlyConfig -quantize_(model, FPXWeightOnlyConfig(3, 2)) -``` - -You can find more information [here](../dtypes/floatx/README.md). It should be noted where most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, performance, fp6 works primarily with the fp16 dtype. - -``` - -KleidiAI Int4 Kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows: - -```python -from torchao.quantization.quant_api import ( - Int8DynamicActivationIntxWeightConfig, - quantize_, -) -from torchao.quantization.granularity import PerGroup, PerAxis -from torchao.quantization.quant_primitives import MappingType -from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler - -my_model = Model() - -quantize_( - my_model, - Int8DynamicActivationIntxWeightConfig( - weight_scale_dtype=torch.float32, - weight_granularity=PerGroup(32), # PerAxis is also supported - weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR, # MappingType.SYMMETRIC can also be used but increases error - layout=layout, - weight_dtype=torch.int4, - intx_packing_format="opaque_aten_kleidiai", - ), -) -``` - #### Workaround with `unwrap_tensor_subclass` for `export`, `AOTI` and `torch.compile` If you are using pytorch 2.6 or before, you need to call `unwrap_tensor_subclass` before `torch.export.export` and `aot_compile`: @@ -222,25 +184,6 @@ Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. 
F | | w4a8 | 197.45 | 653.50 | 4.79 | 3.31 | | | w4a8-g128 | 187.62 | 640.32 | 4.82 | 3.41 | -### Gemlite Triton -Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `GemliteUIntXWeightOnlyConfig`. An example can be found in `torchao/_models/llama/generate.py`. - -Note: we test on gemlite 0.4.1, but should be able to use any version after that, we'd recommend to use the latest release to get the most recent performance improvements. - -### UINTx Quantization -We're trying to develop kernels for low bit quantization for intx quantization formats. While the current performance is not ideal, we're hoping to continue to iterate on these kernels to improve their performance. - -| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | -| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | -| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | -| | uintx-4-64-hqq | 12.775 | 50.99 | 200.08 | 6.29 | 3.92 | -| | uintx-2-8-hqq | 24.500 | 40.25 | 265.95 | 9.24 | 6.61 | -| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | -| | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 | -| | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 | - -You try can out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in in `torchao/_models/llama/generate.py`. - ### Int8DynamicActivationIntxWeightConfig Quantization We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computers with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 perf cores, and 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used. @@ -431,6 +374,62 @@ We've added kv cache quantization and other features in order to enable long con In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](../../torchao/_models/llama/README.md#KV-Cache-Quantization-Memory-Efficient-Inference) +#### A16W6 Floating Point WeightOnly Quantization + +```python +# for torch 2.4+ +from torchao.quantization import quantize_, FPXWeightOnlyConfig +quantize_(model, FPXWeightOnlyConfig(3, 2)) +``` + +You can find more information [here](../dtypes/floatx/README.md). It should be noted where most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, performance, fp6 works primarily with the fp16 dtype. 
+ + ``` + KleidiAI Int4 Kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows: + +```python +from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + quantize_, +) +from torchao.quantization.granularity import PerGroup, PerAxis +from torchao.quantization.quant_primitives import MappingType +from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler + +my_model = Model() + +quantize_( + my_model, + Int8DynamicActivationIntxWeightConfig( + weight_scale_dtype=torch.float32, + weight_granularity=PerGroup(32), # PerAxis is also supported + weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR, # MappingType.SYMMETRIC can also be used but increases error + layout=layout, + weight_dtype=torch.int4, + intx_packing_format="opaque_aten_kleidiai", + ), +) +``` + +### Gemlite Triton +Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `GemliteUIntXWeightOnlyConfig`. An example can be found in `torchao/_models/llama/generate.py`. + +Note: we test on gemlite 0.4.1, but any later version should work; we recommend using the latest release to get the most recent performance improvements. + +### UINTx Quantization +We're trying to develop kernels for low bit intx quantization formats. While the current performance is not ideal, we're hoping to continue to iterate on these kernels to improve their performance. + +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | +| | uintx-4-64-hqq | 12.775 | 50.99 | 200.08 | 6.29 | 3.92 | +| | uintx-2-8-hqq | 24.500 | 40.25 | 265.95 | 9.24 | 6.61 | +| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | +| | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 | +| | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 | + +You can try out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in `torchao/_models/llama/generate.py`.
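The hunks above move several deprecated workflow configs into `torchao/prototype/quantization/quant_api.py`. Below is a minimal usage sketch of what this looks like from user code; it assumes `torchao.quantization` keeps exposing the classes through the `# noqa: F401` backward-compatibility re-exports added in the next file, and the asserted warning text matches what `test_config_deprecation` now checks.

```python
# Sketch, assuming the backward-compatibility re-exports land as shown in this diff:
# the deprecated configs now live under torchao.prototype.quantization.quant_api,
# but the old import path keeps working.
import warnings

import torchao.prototype.quantization.quant_api as proto_quant_api
from torchao.quantization import FPXWeightOnlyConfig  # old path, kept for BC

# The re-exported name refers to the same class object as the prototype module's.
assert FPXWeightOnlyConfig is proto_quant_api.FPXWeightOnlyConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    FPXWeightOnlyConfig(3, 2)  # fp6 (e3m2) weight-only config

# Constructing a deprecated config now warns that it "will be deleted in a future release".
assert any("will be deleted in a future release" in str(w.message) for w in caught)
```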
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bc6b7ccc8d..a0c7e87765 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -38,19 +38,15 @@ Float8Layout, Int4CPULayout, Int4XPULayout, - Int8DynamicActInt4WeightCPULayout, - MarlinQQQLayout, MarlinSparseLayout, PackedLinearInt8DynamicActivationIntxWeightLayout, PlainLayout, QDQLayout, SemiSparseLayout, TensorCoreTiledLayout, - UintxLayout, to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, - to_marlinqqq_quantized_intx, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( Target, @@ -66,6 +62,17 @@ _granularity_is_a_1_128_w_128_128, _normalize_granularity, ) + +# for BC, make sure to keep the `noqa: F401` comments to prevent +# ruff from removing "unused imports" +from torchao.prototype.quantization.quant_api import ( + Float8StaticActivationFloat8WeightConfig, # noqa: F401 + FPXWeightOnlyConfig, # noqa: F401 + GemliteUIntXWeightOnlyConfig, # noqa: F401 + Int4DynamicActivationInt4WeightConfig, # noqa: F401 + Int8DynamicActivationInt4WeightConfig, # noqa: F401 + UIntXWeightOnlyConfig, # noqa: F401 +) from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, ) @@ -94,9 +101,11 @@ _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) -from torchao.quantization.utils import get_block_size -from torchao.quantization.weight_tensor_linear_activation_quantization import ( - to_weight_tensor_with_linear_activation_quantization_metadata, +from torchao.quantization.utils import ( + _fp8_mm_compat, + _linear_extra_repr, + _quantization_type, + get_block_size, ) from torchao.utils import ( is_MI300, @@ -380,26 +389,6 @@ def convert_to_linear_observer(linear_module: nn.Linear): ) -def _quantization_type(weight: torch.Tensor): - if isinstance(weight, AffineQuantizedTensor): - return f"{weight.__class__.__name__}({weight._quantization_type()})" - - if isinstance(weight, LinearActivationQuantizedTensor): - return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" - - if hasattr(weight, "_quantization_type"): - return f"{weight.__class__.__name__}({weight._quantization_type()})" - - if type(weight) is torch.Tensor or isinstance(weight, torch.nn.Parameter): - return f"Tensor: {type(weight)}" - - return f"not recognized: {type(weight)}" - - -def _linear_extra_repr(self): - return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" - - def _embedding_extra_repr(self): return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}" @@ -463,7 +452,6 @@ def quantize_( # optimized execution paths or kernels (e.g. 
int4 tinygemm kernel) # also customizable with arguments # currently options are - # Int8DynamicActivationInt4WeightConfig (for executorch) # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile) # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile @@ -569,116 +557,6 @@ def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor: ) -@dataclass -class Int8DynamicActivationInt4WeightConfig(AOBaseConfig): - """Configuration for applying int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear - This is used to produce a model for executorch backend, but currently executorch did not - support lowering for the quantized model from this flow yet - - Args: - `group_size`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained - `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now - `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric - `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric - `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. - """ - - group_size: int = 32 - layout: Layout = PlainLayout() - mapping_type: MappingType = MappingType.SYMMETRIC - act_mapping_type: MappingType = MappingType.ASYMMETRIC - set_inductor_config: bool = True - - def __post_init__(self): - torch._C._log_api_usage_once( - "torchao.quantization.Int8DynamicActivationInt4WeightConfig" - ) - warnings.warn( - "`Int8DynamicActivationInt4WeightConfig` will be moving to prototype in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
- ) - - -@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) -def _int8_dynamic_activation_int4_weight_transform( - module: torch.nn.Module, - config: Int8DynamicActivationInt4WeightConfig, - *, - custom_scale: Optional[torch.Tensor] = None, - custom_zero_point: Optional[torch.Tensor] = None, -): - group_size = config.group_size - layout = config.layout - mapping_type = config.mapping_type - act_mapping_type = config.act_mapping_type - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - - if group_size is None or group_size == -1: - group_size = weight.shape[-1] - if weight.shape[-1] % group_size != 0: - return module - - # weight settings - block_size = (1, group_size) - target_dtype = torch.int8 - quant_min = -8 - quant_max = 7 - - # input settings - if act_mapping_type == MappingType.ASYMMETRIC: - if isinstance(layout, Int8DynamicActInt4WeightCPULayout): - input_quant_func = _uint8_asymm_per_token_quant - else: - input_quant_func = _int8_asymm_per_token_quant - elif act_mapping_type == MappingType.SYMMETRIC: - if isinstance(layout, MarlinQQQLayout): - input_quant_func = _int8_symm_per_token_quant - elif isinstance(layout, CutlassInt4PackedLayout): - input_quant_func = _int8_symm_cutlass_quant - else: - input_quant_func = _int8_symm_per_token_quant - else: - assert False, f"Unsupported activation mapping type: {act_mapping_type}" - - if isinstance(layout, MarlinQQQLayout): - weight = to_marlinqqq_quantized_intx( - weight, block_size, quant_min, quant_max, _layout=layout - ) - elif isinstance(layout, CutlassInt4PackedLayout): - weight = _int4_symm_cutlass_quant(weight) - elif isinstance(layout, Int8DynamicActInt4WeightCPULayout): - weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype=torch.uint8, - quant_min=0, - quant_max=15, - _layout=layout, - custom_scale=custom_scale, - custom_zero_point=custom_zero_point, - ) - else: - weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - _layout=layout, - custom_scale=custom_scale, - custom_zero_point=custom_zero_point, - ) - weight = to_linear_activation_quantized(weight, input_quant_func) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - - @dataclass class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): """ @@ -942,123 +820,6 @@ def _int8_dynamic_activation_intx_weight_transform( return module -@dataclass -class Int4DynamicActivationInt4WeightConfig(AOBaseConfig): - """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear - - Args: - `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now - `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric - `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric - `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. 
- """ - - layout: Layout = CutlassInt4PackedLayout() - mapping_type: MappingType = MappingType.SYMMETRIC - act_mapping_type: MappingType = MappingType.SYMMETRIC - set_inductor_config: bool = True - - def __post_init__(self): - torch._C._log_api_usage_once( - "torchao.quantization.Int4DynamicActivationInt4WeightConfig" - ) - warnings.warn( - "`Int4DynamicActivationInt4WeightConfig` will be moving to prototype in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." - ) - - -@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) -def _int4_dynamic_activation_int4_weight_transform( - module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig -) -> torch.nn.Module: - weight = module.weight - layout = config.layout - mapping_type = config.mapping_type - act_mapping_type = config.act_mapping_type - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - if not isinstance(layout, CutlassInt4PackedLayout): - raise NotImplementedError( - f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." - ) - if mapping_type != MappingType.SYMMETRIC: - raise NotImplementedError("Only mapping_type=SYMMETRIC is supported.") - if act_mapping_type != MappingType.SYMMETRIC: - raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") - - weight = _int4_symm_cutlass_quant(weight) - weight = to_linear_activation_quantized( - weight, - _int4_symm_cutlass_quant, - ) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - - -@dataclass -class GemliteUIntXWeightOnlyConfig(AOBaseConfig): - """ - applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format. - This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric. - - Args: - `group_size`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained - `bit_width`: bit width of the quantized weight. - `packing_bitwidth`: bit width of the packed weight, should be 8 or 32. Can have performance impacts depending on hardware. - `mode`: if set to "dynamic", activations are quantized at runtime; default is "weight_only" (weight-only quantization). - `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. - """ - - group_size: Optional[int] = 128 - bit_width: int = 4 - packing_bitwidth: Optional[int] = None - mode: Optional[str] = "weight_only" - set_inductor_config: bool = True - - def __post_init__(self): - torch._C._log_api_usage_once( - "torchao.quantization.GemliteUIntXWeightOnlyConfig" - ) - warnings.warn( - "`GemliteUIntXWeightOnlyConfig` will be moving to prototype in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
- ) - - -@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) -def _gemlite_uintx_weight_only_transform( - module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig -): - group_size = config.group_size - bit_width = config.bit_width - packing_bitwidth = config.packing_bitwidth - mode = config.mode - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - - from torchao.prototype.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs - - use_hqq = True if bit_width == 4 else False - new_weight = to_affine_quantized_intx( - weight, - **get_gemlite_aqt_kwargs( - weight, - group_size=group_size, - bit_width=bit_width, - packing_bitwidth=packing_bitwidth, - mode=mode, - use_hqq=use_hqq, - ), - ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - - @dataclass class Int4WeightOnlyConfig(AOBaseConfig): """ @@ -1814,33 +1575,6 @@ def _input_activation_quant_func_fp8( return activation -def _fp8_mm_compat(weight: torch.Tensor) -> bool: - """ - Check if a weight tensor meets float8 quantization requirements. - - Args: - weight (torch.Tensor): The weight tensor to check - - Returns: - bool: True if the tensor can be quantized to float8, False otherwise - """ - assert weight.dim() in [ - 2, - 3, - ], f"float8 quantization only works for 2/3-D tensors, got {weight.dim()}D tensor" - - out_dim, in_dim = weight.shape[-2:] - is_compatible = (in_dim % 16 == 0) and (out_dim % 16 == 0) - - if not is_compatible: - logger.info( - f"Skipping float8 quantization: weight shape {weight.shape} is not compatible with _scaled_mm. " - f"Both input dimension ({in_dim}) and output dimension ({out_dim}) must be multiples of 16. " - ) - - return is_compatible - - @dataclass class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): """ @@ -2072,191 +1806,6 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform( return module -@dataclass -class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): - """ - Configuration for applying float8 static symmetric quantization to - - Args: - scale (torch.Tensor): The scale tensor for activation quantization. - activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m - weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m - mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. - set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. - """ - - scale: torch.Tensor - activation_dtype: torch.dtype = e4m3_dtype - weight_dtype: torch.dtype = e4m3_dtype - granularity: Optional[ - Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]] - ] = None - mm_config: Optional[Float8MMConfig] = Float8MMConfig(use_fast_accum=True) - set_inductor_config: bool = True - - def __post_init__(self): - torch._C._log_api_usage_once( - "torchao.quantization.Float8StaticActivationFloat8WeightConfig" - ) - warnings.warn( - "`Float8StaticActivationFloat8WeightConfig` will be moving to prototype in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
- ) - - -@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) -def _float8_static_activation_float8_weight_transform( - module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig -): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 static activation quantization is only supported on CUDA 8.9 and above" - ) - - if isinstance(module, Float8Linear): - module = _unwrap_float8_linear(module) - - scale = config.scale - activation_dtype = config.activation_dtype - weight_dtype = config.weight_dtype - granularity = config.granularity - mm_config = config.mm_config - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - activation_granularity, weight_granularity = _normalize_granularity(granularity) - assert isinstance(activation_granularity, PerTensor), ( - "Static quantization only supports PerTensor granularity" - ) - - if not _fp8_mm_compat(weight): - # TODO(future PR): this should really throw an exception instead of silently - # not doing what the user asked - return module - block_size = get_block_size(weight.shape, weight_granularity) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) - - input_quant_func = _input_activation_quant_func_fp8 - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } - - quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata( - quantized_weight, - input_quant_func, - scale=scale, - zero_point=None, - quant_kwargs=input_quant_kwargs, - ) - - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - - -@dataclass -class UIntXWeightOnlyConfig(AOBaseConfig): - """ - Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where - x is the number of bits specified by `dtype` - - Args: - `dtype`: torch.uint1 to torch.uint7 sub byte dtypes - `group_size`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained, defaults to 64 - `pack_dim`: the dimension we use for packing, defaults to -1 - `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight - `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. - """ - - dtype: torch.dtype - group_size: int = 64 - pack_dim: int = -1 - use_hqq: bool = False - set_inductor_config: bool = True - - def __post_init__(self): - torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig") - warnings.warn( - "`UIntXWeightOnlyConfig` will be moving to prototype in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
- ) - - -@register_quantize_module_handler(UIntXWeightOnlyConfig) -def _uintx_weight_only_transform( - module: torch.nn.Module, config: UIntXWeightOnlyConfig -): - dtype = config.dtype - group_size = config.group_size - pack_dim = config.pack_dim - use_hqq = config.use_hqq - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - - from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS - - SUPPORTED_DTYPES = { - torch.uint1, - torch.uint2, - torch.uint3, - torch.uint4, - torch.uint5, - torch.uint6, - torch.uint7, - torch.uint8, - } - assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}" - - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - - if use_hqq: - if dtype == torch.uint4: - logger.warning( - "Recommended to use `Int4WeightOnlyConfig(group_size, use_hqq=True, version=1)` for the best performance" - ) - quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] - dtype = torch.uint8 - eps = None - zero_point_dtype = None - zero_point_domain = ZeroPointDomain.FLOAT - preserve_zero = False - _layout = PlainLayout() - else: - quant_min, quant_max = None, None - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int32 - zero_point_domain = ZeroPointDomain.INT - preserve_zero = True - _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) - - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - zero_point_dtype=zero_point_dtype, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - _layout=_layout, - use_hqq=use_hqq, - ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - - def _adjust_scale_dtype_in_intx_unpacked_tensor( intx_unpacked_tensor: IntxUnpackedToInt8Tensor, hp_tensor: torch.Tensor, @@ -2458,62 +2007,6 @@ def _intx_weight_only_transform( return module -@dataclass -class FPXWeightOnlyConfig(AOBaseConfig): - """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits - e.g. fp6_e3_m2, fp6_e2_m3, ... - The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 - github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm - For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTTensorImpl` - - This is experimental, will be merged with `to_affine_quantized_floatx` - in the future - """ - - ebits: int - mbits: int - set_inductor_config: bool = True - - def __post_init__(self): - torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig") - warnings.warn( - "`FPXWeightOnlyConfig` will be moving to prototype in a future release of torchao. Please see https://github.com/pytorch/ao/issues/2752 for more details." 
- ) - - -@register_quantize_module_handler(FPXWeightOnlyConfig) -def _fpx_weight_only_transform( - module: torch.nn.Module, config: FPXWeightOnlyConfig -) -> torch.nn.Module: - ebits = config.ebits - mbits = config.mbits - weight = module.weight - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - if isinstance(module, Float8Linear): - module = _unwrap_float8_linear(module) - - from torchao.dtypes import to_affine_quantized_fpx - from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout - - assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" - out_dim, in_dim = weight.shape - if (in_dim % 64 != 0) or (out_dim % 256 != 0): - logger.info( - f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because " - f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} " - "expected in_dim % 64 == 0 and out_dim % 256 == 0" - ) - return module - - _layout = FloatxTensorCoreLayout(ebits, mbits) - new_weight = to_affine_quantized_fpx(weight, _layout) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - - @dataclass class FqnToConfig(AOBaseConfig): """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs). diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 1a0375f3d2..cbc8c1acc4 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import logging from typing import Dict, List, Optional, Tuple import torch @@ -38,6 +39,9 @@ PerTensor, PerToken, ) +from .linear_activation_quantized_tensor import ( + LinearActivationQuantizedTensor, +) __all__ = [ "compute_error", @@ -58,6 +62,8 @@ "recommended_inductor_config_setter", ] +logger = logging.getLogger(__name__) + # basic SQNR def compute_error(x, y): @@ -735,3 +741,53 @@ def get_block_size( ) return (1,) * (len(input_shape) - 1) + (granularity.group_size,) raise ValueError(f"Unsupported Granularity: {granularity}") + + +def _quantization_type(weight: torch.Tensor): + # prevent circular import + from torchao.dtypes import AffineQuantizedTensor + + if isinstance(weight, AffineQuantizedTensor): + return f"{weight.__class__.__name__}({weight._quantization_type()})" + + if isinstance(weight, LinearActivationQuantizedTensor): + return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" + + if hasattr(weight, "_quantization_type"): + return f"{weight.__class__.__name__}({weight._quantization_type()})" + + if type(weight) is torch.Tensor or isinstance(weight, torch.nn.Parameter): + return f"Tensor: {type(weight)}" + + return f"not recognized: {type(weight)}" + + +def _linear_extra_repr(self): + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" + + +def _fp8_mm_compat(weight: torch.Tensor) -> bool: + """ + Check if a weight tensor meets float8 quantization requirements. 
+ + Args: + weight (torch.Tensor): The weight tensor to check + + Returns: + bool: True if the tensor can be quantized to float8, False otherwise + """ + assert weight.dim() in [ + 2, + 3, + ], f"float8 quantization only works for 2/3-D tensors, got {weight.dim()}D tensor" + + out_dim, in_dim = weight.shape[-2:] + is_compatible = (in_dim % 16 == 0) and (out_dim % 16 == 0) + + if not is_compatible: + logger.info( + f"Skipping float8 quantization: weight shape {weight.shape} is not compatible with _scaled_mm. " + f"Both input dimension ({in_dim}) and output dimension ({out_dim}) must be multiples of 16. " + ) + + return is_compatible
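To close the loop on the `_fp8_mm_compat` move: the helper (together with `_linear_extra_repr` and `_quantization_type`) now lives in `torchao/quantization/utils.py`, which is why the log-capture test near the top of this diff now listens on the `torchao.quantization.utils` logger. A small behavioral sketch, with illustrative tensor shapes:

```python
# Sketch of the relocated shape check: _scaled_mm needs both of the last two
# weight dimensions to be multiples of 16.
import torch

from torchao.quantization.utils import _fp8_mm_compat

# Compatible: 256 % 16 == 0 and 512 % 16 == 0.
assert _fp8_mm_compat(torch.randn(256, 512))

# Incompatible: 25 and 10 are not multiples of 16, so the helper logs
# "Skipping float8 quantization: ..." via the torchao.quantization.utils logger
# and returns False instead of raising.
assert not _fp8_mm_compat(torch.randn(25, 10))
```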