diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py
index faf8231342..6166b80af9 100644
--- a/test/prototype/test_smoothquant.py
+++ b/test/prototype/test_smoothquant.py
@@ -15,12 +15,12 @@
 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
-from torchao.quantization.linear_activation_scale import (
-    WeightTensorWithLinearActivationScaleMetadata,
-)
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
 )
+from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import (
     compute_error as SQNR,
 )
@@ -83,7 +83,9 @@ def setUpClass(cls):
     @common_utils.parametrize(
         "base_config",
         [
-            Int8DynamicActivationInt8WeightConfig(),
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
+            Int8StaticActivationInt8WeightConfig(granularity=PerTensor()),
             # Note: float8_static_activation_float8_weight is broken after recent PyTorch update.
             # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
         ],
@@ -101,7 +103,15 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):

         # Step 1. Basic quantization
         basic_model = deepcopy(m)
-        quantize_(basic_model, base_config)
+        if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+            quantize_(
+                basic_model,
+                Int8DynamicActivationInt8WeightConfig(
+                    version=2, granularity=base_config.granularity
+                ),
+            )
+        else:
+            quantize_(basic_model, base_config)
         out_basic = basic_model(*x)
         loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()

@@ -119,12 +129,10 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         config.step = SmoothQuantStep.CONVERT
         quantize_(model, config)

-        assert isinstance(
-            model.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
-        )
-        assert isinstance(
-            model.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
-        )
+        assert isinstance(model.linear1.weight, SupportsActivationPreScaling)
+        assert isinstance(model.linear2.weight, SupportsActivationPreScaling)
+        assert model.linear1.weight.act_pre_scale is not None
+        assert model.linear2.weight.act_pre_scale is not None

         out_smoothquant = model(*x)
         loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item()
@@ -138,7 +146,7 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
     @common_utils.parametrize(
         "base_config",
         [
-            Int8DynamicActivationInt8WeightConfig(),
+            Int8DynamicActivationInt8WeightConfig(version=2),
             # TODO: Check more quantization APIs
         ],
     )
@@ -177,7 +185,7 @@ def test_observer_insertion(self, base_config):
     @common_utils.parametrize(
         "base_config",
         [
-            Int8DynamicActivationInt8WeightConfig(),
+            Int8DynamicActivationInt8WeightConfig(version=2),
             # TODO: Check more quantization APIs
         ],
     )
diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
index 80147e68ba..5eaf92f06f 100644
--- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -274,17 +274,50 @@ def test_static_activation_per_row_int8_weight(self, granularity, dtype):
         static_out_compile = model_dynamic_quant(input_tensor)
         sqnr_static_compile = compute_error(model_out_baseline, static_out_compile)

-        assert (
-            sqnr_static_compile
-            == sqnr_static_eager
-            == sqnr_dynamic_compile
-            == sqnr_dynamic_eager
-        ), "SQNR should be the same for all quantization methods and eager/compile"
+        assert sqnr_static_compile == sqnr_static_eager, (
+            f"Static SQNR mismatch: compile={sqnr_static_compile} vs eager={sqnr_static_eager}"
+        )
+        assert sqnr_static_eager == sqnr_dynamic_compile, (
+            f"Static eager vs dynamic compile SQNR mismatch: {sqnr_static_eager} vs {sqnr_dynamic_compile}"
+        )
+        assert sqnr_dynamic_compile == sqnr_dynamic_eager, (
+            f"Dynamic SQNR mismatch: compile={sqnr_dynamic_compile} vs eager={sqnr_dynamic_eager}"
+        )

         # eager numerics should match exactly
         # for compile, we can't compare dynamic vs static because we may get slightly different qparams when fused
         torch.testing.assert_close(dynamic_out_eager, static_out_eager)

+    def test_static_per_feature_act_quant_not_supported(self):
+        """Test that PerRow(dim != -1) activation quantization raises an error.
+
+        Per-feature activation quantization (PerRow(dim=0)) would require slicing
+        act_scale when the weight is sliced, which is not currently supported.
+        We explicitly disallow this configuration.
+        """
+        from torchao.quantization.granularity import PerRow as PerRowGranularity
+
+        # Attempting to create a config with PerRow(dim=0) should raise an error
+        with self.assertRaises(ValueError) as cm:
+            static_config = Int8StaticActivationInt8WeightConfig(
+                static_scale=torch.ones(1, 1, device="cuda"),
+                granularity=PerRowGranularity(dim=0),  # This should fail
+                act_mapping_type=MappingType.SYMMETRIC,
+            )
+
+        self.assertIn("PerRow(dim=-1)", str(cm.exception))
+        self.assertIn("per-feature", str(cm.exception).lower())
+
+        # Verify that PerRow() (default dim=-1) and PerTensor() still work
+        for granularity in [PerRow(), PerTensor()]:
+            static_config = Int8StaticActivationInt8WeightConfig(
+                static_scale=torch.ones(1, 1, device="cuda"),
+                granularity=granularity,
+                act_mapping_type=MappingType.SYMMETRIC,
+            )
+            # Should not raise an error
+            self.assertIsNotNone(static_config)
+

 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py
index 9f78c49fb8..99ec98cccc 100644
--- a/torchao/prototype/smoothquant/api.py
+++ b/torchao/prototype/smoothquant/api.py
@@ -10,13 +10,15 @@
 import torch

 from torchao.core.config import AOBaseConfig
-from torchao.quantization.linear_activation_scale import (
-    to_weight_tensor_with_linear_activation_scale_metadata,
-)
 from torchao.quantization.quant_api import (
     _QUANTIZE_CONFIG_HANDLER,
+    Int8StaticActivationInt8WeightConfig,
     _linear_extra_repr,
 )
+from torchao.quantization.quantize_.common import SupportsActivationPreScaling
+from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
+    QuantizeTensorToInt8Kwargs,
+)
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
@@ -95,8 +97,18 @@ def _smooth_quant_transform(
     else:
         raise ValueError(f"Unexpected step: {step}")

+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        quant_kwargs = QuantizeTensorToInt8Kwargs(
+            granularity=base_config.granularity,
+            mapping_type=base_config.act_mapping_type,
+        )
+    else:
+        quant_kwargs = None
+
     # Compute smoothed weight parameters
-    smoothing_factor = observed_linear.obs.calculate_qparams()
+    smoothing_factor, activation_scale = observed_linear.obs.calculate_qparams(
+        weight_quant_kwargs=quant_kwargs
+    )
     weight = observed_linear.weight * smoothing_factor

     # Create new linear layer
@@ -111,15 +123,21 @@ def _smooth_quant_transform(
         linear.bias = observed_linear.bias

     # Quantize weights
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        base_config.static_scale = activation_scale
+
     base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
     dummy_mod = DummyModule(weight)
     quant_mod = base_config_handler(dummy_mod, base_config)
     qw = quant_mod.weight

-    # Add smoothing factor metadata
-    qw = to_weight_tensor_with_linear_activation_scale_metadata(
-        qw, smoothing_factor.to(qw.dtype)
+    # Add smoothing factor as activation pre-scale
+    assert isinstance(qw, SupportsActivationPreScaling), (
+        "weight must support activation scaling through implementing `SupportsActivationPreScaling`"
     )
+    # Store reciprocal for runtime efficiency: act * act_pre_scale
+    qw.act_pre_scale = 1.0 / smoothing_factor
+
     linear.weight = torch.nn.Parameter(qw, requires_grad=False)
     linear.extra_repr = types.MethodType(_linear_extra_repr, linear)

diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py
index 83f1e78275..9974bf3719 100644
--- a/torchao/prototype/smoothquant/core.py
+++ b/torchao/prototype/smoothquant/core.py
@@ -9,6 +9,10 @@
 import torch
 import torch.nn.functional as F

+from torchao.quantization.quantize_.common import (
+    _choose_quant_func_and_quantize_tensor,
+)
+

 class SmoothQuantStep(str, Enum):
     PREPARE = "prepare"
@@ -41,13 +45,14 @@ def forward(self, input: torch.Tensor):
             self.inputs.append(input.to("cpu"))
         return input

-    def calculate_qparams(self):
+    def calculate_qparams(self, weight_quant_kwargs=None):
         assert self.inputs and len(self.inputs) > 0, (
             "calibrate observer first by running model on exemplar data"
         )
         inputs = [inp.to(self.device) for inp in self.inputs]
         acc = torch.cat(inputs, dim=0)
         # Reshape if needed: [batch, seq, features] -> [batch*seq, features]
+        example_input_for_quantization = acc
         if acc.ndim > 2:
             acc = acc.view(-1, acc.shape[-1])

@@ -57,12 +62,20 @@ def calculate_qparams(self):

         # Calculate smoothing factor
         if self.alpha is None:
-            return torch.ones_like(x_abs_max)
+            smoothing_factor = torch.ones_like(x_abs_max)
+        else:
+            eps = torch.finfo(torch.float32).eps
+            smoothing_factor = torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
+                w_abs_max + eps, 1 - self.alpha
+            )

-        eps = torch.finfo(torch.float32).eps
-        return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
-            w_abs_max + eps, 1 - self.alpha
-        )
+        if weight_quant_kwargs is not None:
+            quant_smooth_activation = _choose_quant_func_and_quantize_tensor(
+                example_input_for_quantization / smoothing_factor, weight_quant_kwargs
+            )
+            return smoothing_factor, quant_smooth_activation.scale
+        else:
+            return smoothing_factor, None


 class SmoothQuantObservedLinear(torch.nn.Linear):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e19e35c20c..a0d5387745 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1651,14 +1651,14 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
     Configuration for applying int8 static symmetric quantization to both activation and weight

     Args:
-        scale (torch.Tensor): The scale tensor for activation quantization.
+        static_scale (torch.Tensor): The scale tensor for activation quantization.
         granularity (Granularity): The granularity of quantization. PerRow() and PerTensor() are supported currently
         act_mapping_type (MappingType): The mapping type for activation quantization. only SYMMETRIC is supported currently
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
         version (int): the version of the config
     """

-    scale: torch.Tensor
+    static_scale: Optional[torch.Tensor] = None
     granularity: Granularity = PerRow()
     act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
     set_inductor_config: bool = True
@@ -1669,6 +1669,14 @@ def __post_init__(self):
             "torchao.quantization.Int8StaticActivationInt8WeightConfig"
         )

+        # Validate activation granularity for static quantization
+        if isinstance(self.granularity, PerRow) and self.granularity.dim != -1:
+            raise ValueError(
+                f"Int8StaticActivationInt8WeightConfig only supports PerRow(dim=-1) "
+                f"for activation quantization, got PerRow(dim={self.granularity.dim}). "
+                f"Per-feature activation quantization is not supported due to slicing limitations."
+            )
+

 @register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
 def _int8_static_activation_int8_weight_transform(
@@ -1700,7 +1708,7 @@ def _int8_static_activation_int8_weight_transform(
             granularity=activation_granularity,
             mapping_type=config.act_mapping_type,
         ),
-        act_scale=config.scale.detach(),
+        act_scale=config.static_scale.detach(),
     )

     setattr(
diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
index ca16fa6326..953369aa97 100644
--- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-
 from dataclasses import dataclass
 from typing import List, Optional

@@ -60,7 +59,7 @@ class Int8Tensor(TorchAOBaseTensor):
     """

     tensor_data_names = ["qdata", "scale"]
-    optional_tensor_data_names = ["act_scale"]
+    optional_tensor_data_names = ["act_scale", "act_pre_scale"]
     tensor_attribute_names = ["block_size", "dtype"]
     optional_tensor_attribute_names = [
         "act_quant_kwargs",
@@ -73,6 +72,7 @@ def __new__(
         block_size: List[int],
         dtype: torch.dtype,
         act_scale=None,
+        act_pre_scale: Optional[torch.Tensor] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         kwargs = {
@@ -89,6 +89,7 @@ def __init__(
         block_size: List[int],
         dtype: torch.dtype,
         act_scale=None,
+        act_pre_scale: Optional[torch.Tensor] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         super().__init__()
@@ -98,6 +99,7 @@ def __init__(
         # don't set dtype because this gets done in __new__
         self.act_quant_kwargs = act_quant_kwargs
         self.act_scale = act_scale
+        self.act_pre_scale = act_pre_scale

     def __repr__(self):
         return (
@@ -106,6 +108,7 @@ def __repr__(self):
             f"qdata={self.qdata}, "
             f"scale={self.scale}, "
             f"act_scale={self.act_scale}, "
+            f"act_pre_scale={self.act_pre_scale}, "
             f"block_size={self.block_size}, "
             f"shape={self.shape}, "
             f"device={self.device}, "
@@ -121,6 +124,7 @@ def from_hp(
         scale: Optional[torch.Tensor] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
         act_scale: Optional[torch.Tensor] = None,
+        act_pre_scale: Optional[torch.Tensor] = None,
     ):
         """Create Int8Tensor from high-precision tensor"""
         block_size = get_block_size(hp_tensor.shape, granularity)
@@ -161,6 +165,7 @@ def from_hp(
             block_size,
             hp_tensor.dtype,
             act_scale=act_scale,
+            act_pre_scale=act_pre_scale,
             act_quant_kwargs=act_quant_kwargs,
         )

@@ -198,13 +203,18 @@ def _(func, types, args, kwargs):

     output_dtype = activation_tensor.dtype

+    # Apply activation pre-scaling if present (for AWQ, SmoothQuant, etc.)
+    if weight_tensor.act_pre_scale is not None:
+        activation_tensor = activation_tensor * weight_tensor.act_pre_scale
+
     if weight_tensor.act_quant_kwargs is not None:
+        # for int8 dynamic + static quantization path
+
         activation_tensor = _choose_quant_func_and_quantize_tensor(
             activation_tensor,
             weight_tensor.act_quant_kwargs,
             scale=weight_tensor.act_scale,
         )
-    # Dynamic activation quantization path

     # 1. do the matrix form of dot(X_i, W_j)
     #
@@ -292,6 +302,8 @@ def _(func, types, args, kwargs):
                 block_size,
                 self.dtype,
                 act_quant_kwargs=self.act_quant_kwargs,
+                act_scale=self.act_scale,
+                act_pre_scale=self.act_pre_scale,
             ),
         )

@@ -322,6 +334,8 @@ def _(func, types, args, kwargs):
             old_int8_tensor.scale[index],
             old_int8_tensor.block_size[1:],
             old_int8_tensor.dtype,
+            old_int8_tensor.act_scale,
+            old_int8_tensor.act_pre_scale,
             old_int8_tensor.act_quant_kwargs,
         )
     return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor)
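
For reviewers, a minimal end-to-end sketch (not part of the patch) of the flow this change targets: the smoothing factor is stored on the quantized weight as `act_pre_scale`, and the calibrated activation scale feeds `Int8StaticActivationInt8WeightConfig`. `SmoothQuantConfig` and its `base_config`/`step`/`alpha` fields are assumed from the existing prototype and are not shown in this diff; the updated test above exercises the same prepare/calibrate/convert steps.

```python
import torch

from torchao.prototype.smoothquant import SmoothQuantConfig  # assumed import path
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig

# Small bf16 CUDA model, mirroring the test setup.
model = torch.nn.Sequential(torch.nn.Linear(512, 256)).cuda().to(torch.bfloat16)

# Step 1: swap nn.Linear for observed linears that record activation ranges.
config = SmoothQuantConfig(
    base_config=Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
    step=SmoothQuantStep.PREPARE,
    alpha=0.5,
)
quantize_(model, config)

# Step 2: calibrate so the observer can compute the smoothing factor and,
# for the static base config, the activation scale.
for _ in range(4):
    model(torch.randn(8, 512, device="cuda", dtype=torch.bfloat16))

# Step 3: convert. The reciprocal smoothing factor lands in weight.act_pre_scale
# and is applied to activations before static int8 quantization at runtime.
config.step = SmoothQuantStep.CONVERT
quantize_(model, config)
assert model[0].weight.act_pre_scale is not None
```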