diff --git a/examples/awq_qdq_usage_example.py b/examples/awq_qdq_usage_example.py
new file mode 100644
index 0000000000..86fd440a68
--- /dev/null
+++ b/examples/awq_qdq_usage_example.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Simple usage example for AWQ with QDQLayout and ExecuTorch support.
+
+This example demonstrates the complete workflow for using AWQ quantization
+with QDQLayout support and 8-bit dynamic activation quantization.
+"""
+
+import torch
+import torch.nn as nn
+
+from torchao.prototype.awq import (
+    insert_awq_observer_qdq_,
+    AWQQDQConfig,
+)
+from torchao.prototype.awq.executorch_awq import _is_awq_observed_linear_qdq
+from torchao.quantization import quantize_
+
+
+def main():
+    print("AWQ + QDQLayout + ExecuTorch Example")
+    print("=" * 40)
+
+    # 1. Create a simple model
+    model = nn.Sequential(
+        nn.Linear(512, 1024),
+        nn.ReLU(),
+        nn.Linear(1024, 256),
+    )
+
+    print(f"Original model parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    # 2. Insert AWQ observers with QDQLayout support
+    print("\n1. Inserting AWQ observers...")
+    insert_awq_observer_qdq_(
+        model,
+        n_validation_examples=5,
+        validation_sequence_len=64,
+        quant_dtype=torch.uint4,
+        group_size=128,
+        # 8-bit dynamic activation quantization is applied later by the AWQQDQConfig transform
+    )
+
+    print("   Observers inserted successfully!")
+
+    # 3. Calibrate the model
+    print("\n2. Calibrating model...")
+    model.eval()
+    with torch.no_grad():
+        for i in range(5):
+            # Generate random calibration data
+            example_input = torch.randn(2, 64, 512)
+            model(example_input)
+            print(f"   Calibration step {i + 1}/5 completed")
+
+    print("   Calibration completed!")
+
+    # 4. Apply AWQ quantization with QDQLayout
+    print("\n3. Applying AWQ quantization with QDQLayout...")
+    config = AWQQDQConfig(
+        quant_dtype=torch.uint4,
+        group_size=128,
+        # 8-bit dynamic activation quantization is built into this config's transform
+    )
+
+    # Use the custom filter to target AWQObservedLinearQDQ modules
+    quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq)
+
+    print("   Quantization applied successfully!")
+
+    # 5. Test the quantized model
+    print("\n4. Testing quantized model...")
+    test_input = torch.randn(1, 64, 512)
+
+    with torch.no_grad():
+        output = model(test_input)
+        print(f"   Input shape: {test_input.shape}")
+        print(f"   Output shape: {output.shape}")
+
+    # 6. Verify QDQLayout usage
+    print("\n5. 
Verifying QDQLayout tensors...") + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + weight = module.weight + if hasattr(weight, "__tensor_flatten__"): + print(f" ✓ {name}: Uses quantized tensor (QDQLayout)") + # Check for QDQLayout specific attributes + if hasattr(weight, "int_data"): + print(f" - int_data shape: {weight.int_data.shape}") + print(f" - scale shape: {weight.scale.shape}") + else: + print(f" ✗ {name}: Uses regular tensor") + + print("\n" + "=" * 40) + print("AWQ + QDQLayout quantization completed successfully!") + print("The model is now ready for ExecuTorch export.") + + return model + + +if __name__ == "__main__": + # Set random seed for reproducibility + torch.manual_seed(42) + + # Run the example + quantized_model = main() + + print(f"\nFinal model type: {type(quantized_model)}") + print("Example completed successfully!") diff --git a/test/prototype/test_awq_executorch.py b/test/prototype/test_awq_executorch.py new file mode 100644 index 0000000000..6b758e14dc --- /dev/null +++ b/test/prototype/test_awq_executorch.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +import torch +import torch.nn.functional as F +from torchao.prototype.awq import ( + insert_awq_observer_qdq_, + AWQQDQConfig, +) +from torchao.prototype.awq.executorch_awq import _is_awq_observed_linear_qdq +from torchao.quantization import quantize_ +from torchao.dtypes.uintx.q_dq_layout import QDQLayout + + +class TestAWQExecutorchIntegration(unittest.TestCase): + """Test suite for AWQ + QDQLayout + ExecuTorch integration.""" + + def setUp(self): + """Set up test fixtures.""" + torch.manual_seed(42) + + # Create a simple test model + self.model = torch.nn.Sequential( + torch.nn.Linear(64, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 32), + ) + + # Example input for testing + self.example_input = torch.randn(2, 16, 64) + self.batch_size, self.seq_len, self.hidden_size = self.example_input.shape + + def test_awq_observer_insertion(self): + """Test insertion of AWQ observers with QDQLayout support.""" + model = torch.nn.Sequential( + torch.nn.Linear(64, 128), + torch.nn.Linear(128, 32), + ) + + # Insert AWQ observers + insert_awq_observer_qdq_( + model, + n_validation_examples=2, + validation_sequence_len=16, + quant_dtype=torch.int4, + group_size=64, + ) + + # Check that Linear layers were replaced with AWQObservedLinearQDQ + from torchao.prototype.awq.executorch_awq import AWQObservedLinearQDQ + + for module in model.modules(): + if isinstance(module, torch.nn.Linear): + # Should be replaced with AWQObservedLinearQDQ + self.assertIsInstance(module, AWQObservedLinearQDQ) + # Check observer configuration + self.assertEqual(module.act_obs.n_validation_examples, 2) + self.assertEqual(module.act_obs.validation_sequence_len, 16) + + def test_awq_calibration_and_quantization(self): + """Test AWQ calibration and quantization with QDQLayout.""" + model = torch.nn.Sequential(torch.nn.Linear(64, 128)) + + # Insert AWQ observer + insert_awq_observer_qdq_( + model, + n_validation_examples=3, + validation_sequence_len=16, + quant_dtype=torch.int4, + group_size=32, + ) + + # Calibrate the model + model.eval() + with torch.no_grad(): + for _ in range(3): + example_input = torch.randn(2, 16, 64) + model(example_input) + + # Apply quantization + config = AWQQDQConfig( + 
quant_dtype=torch.int4, + group_size=32, + ) + quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq) + + # Verify the model is quantized (model is modified in-place) + self.assertIsInstance(model, torch.nn.Sequential) + self.assertIsInstance(model[0], torch.nn.Linear) + + # Check that weight uses QDQLayout + weight_tensor = model[0].weight + self.assertTrue(hasattr(weight_tensor, "__tensor_flatten__")) # AQT tensor + + # Test forward pass + with torch.no_grad(): + output = model(self.example_input) + self.assertEqual(output.shape, (2, 16, 128)) + + def test_multiple_quantization_dtypes(self): + """Test AWQ with different quantization dtypes.""" + for quant_dtype in [torch.uint1, torch.uint2, torch.int4]: + with self.subTest(quant_dtype=quant_dtype): + model = torch.nn.Sequential(torch.nn.Linear(32, 64)) + + # Insert observer + insert_awq_observer_qdq_( + model, + n_validation_examples=2, + validation_sequence_len=4, + quant_dtype=quant_dtype, + group_size=16, + ) + + # Calibrate + model.eval() + with torch.no_grad(): + for _ in range(2): + model(torch.randn(1, 4, 32)) + + # Quantize + config = AWQQDQConfig(quant_dtype=quant_dtype, group_size=16) + quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq) + + # Test forward pass + with torch.no_grad(): + output = model(torch.randn(1, 4, 32)) + self.assertEqual(output.shape, (1, 4, 64)) + + def test_different_group_sizes(self): + """Test AWQ with different group sizes.""" + for group_size in [16, 32, 64, 128]: + with self.subTest(group_size=group_size): + model = torch.nn.Sequential(torch.nn.Linear(128, 64)) + + # Insert observer + insert_awq_observer_qdq_( + model, + n_validation_examples=2, + validation_sequence_len=4, + quant_dtype=torch.int4, + group_size=group_size, + ) + + # Calibrate + model.eval() + with torch.no_grad(): + for _ in range(2): + model(torch.randn(1, 4, 128)) + + # Quantize + config = AWQQDQConfig(quant_dtype=torch.int4, group_size=group_size) + quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq) + + # Test forward pass + with torch.no_grad(): + output = model(torch.randn(1, 4, 128)) + self.assertEqual(output.shape, (1, 4, 64)) + + def test_graph_pattern_for_executorch(self): + """Test that the graph pattern matches ExecuTorch expectations for XNNPACK lowering.""" + model = torch.nn.Sequential(torch.nn.Linear(128, 64)) + + # Insert AWQ observers with dynamic activation quantization + insert_awq_observer_qdq_( + model, + n_validation_examples=2, + validation_sequence_len=8, + quant_dtype=torch.int4, + group_size=32, + ) + + # Calibrate + model.eval() + with torch.no_grad(): + for _ in range(2): + model(torch.randn(1, 8, 128)) + + # Quantize + config = AWQQDQConfig( + quant_dtype=torch.int4, + group_size=32, + ) + quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq) + + # Test the forward method applies the expected AWQ + dynamic activation quantization pattern + example_input = torch.randn(1, 8, 128) + + # Test that forward pass runs without error + with torch.no_grad(): + actual_output = model(example_input) + + # Verify output shape is correct + self.assertEqual(actual_output.shape, (1, 8, 64)) + + # Test graph pattern using torch.export (the proper way for ExecuTorch) + # Export with strict=True for ExecuTorch compatibility + exported_program = torch.export.export(model, (example_input,), strict=True) + + # Test that exported model produces same results + exported_results = exported_program.module()(example_input) + self.assertTrue( + torch.allclose(actual_output, 
exported_results, atol=1e-3), + "Exported model should produce same results as original", + ) + + # Use FileCheck to verify the graph contains required operations for AWQ + dynamic activation quantization + from torch.testing import FileCheck + + # Expected operations in the exported graph for AWQ + dynamic activation quantization + # This pattern is what ExecuTorch can recognize and lower to XNNPACK: + # 1. AWQ scaling (division operation) + # 2. Dynamic activation quantization (choose_qparams, quantize, dequantize) + # 3. Weight quantization/dequantization (from QDQLayout) + # 4. Linear operation on dequantized tensors + expected_operations = [ + # AWQ scaling - division operation to scale input by AWQ scale + "torch.ops.aten.div.Tensor", + # Dynamic activation quantization - choose quantization parameters + "torch.ops.torchao.choose_qparams_affine.default", + # Dynamic activation quantization - quantize activation + "torch.ops.torchao.quantize_affine.default", + # Dynamic activation dequantization - dequantize activation for linear op + "torch.ops.torchao.dequantize_affine.default", + # Linear operation on dequantized tensors + "torch.ops.aten.linear.default", + ] + + # Verify each required operation appears in the exported graph + for operation in expected_operations: + count = 1 + # We expect 2 dequantize operations: one for activation, one for weight + if operation == "torch.ops.torchao.dequantize_affine.default": + count = 2 + FileCheck().check_count(operation, count, exactly=True).run( + exported_program.graph_module.code + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index 570b0821d4..d43fc1ebf8 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,8 +1,21 @@ from .api import awq_uintx, insert_awq_observer_ from .core import AWQObservedLinear +from .executorch_awq import ( + insert_awq_observer_qdq_, + AWQQDQConfig, + AWQObserverQDQ, + AWQObservedLinearQDQ, + _is_awq_observed_linear_qdq, +) __all__ = [ "awq_uintx", "insert_awq_observer_", "AWQObservedLinear", + # ExecuTorch AWQ support + "insert_awq_observer_qdq_", + "AWQQDQConfig", + "AWQObserverQDQ", + "AWQObservedLinearQDQ", + "_is_awq_observed_linear_qdq", ] diff --git a/torchao/prototype/awq/executorch_awq.py b/torchao/prototype/awq/executorch_awq.py new file mode 100644 index 0000000000..c028c630b6 --- /dev/null +++ b/torchao/prototype/awq/executorch_awq.py @@ -0,0 +1,441 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +AWQ implementation with QDQLayout support and dynamic activation quantization for ExecuTorch. + +This module extends the existing AWQ implementation to support: +1. QDQLayout (Quantize-Dequantize Layout) for ExecuTorch compatibility +2. 8-bit dynamic activation quantization for scales +3. 
Improved quantization workflow for ExecuTorch deployment +""" + +import types +from dataclasses import dataclass +from typing import Optional, Callable, Dict, Any + +import torch +import torch.nn.functional as F + +import torchao +from torchao.core.config import AOBaseConfig +from torchao.dtypes import to_affine_quantized_intx +from torchao.dtypes.uintx.q_dq_layout import QDQLayout +from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH +from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata +from torchao.quantization.granularity import PerGroup +from torchao.quantization.observer import AffineQuantizedObserverBase +from torchao.quantization.quant_api import ( + _linear_extra_repr, + _replace_with_custom_fn_if_matches_filter, +) +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_QVALUE_BOUNDS, + MappingType, + ZeroPointDomain, +) +from torchao.quantization.transform_module import register_quantize_module_handler +from torchao.quantization.quant_api import _int8_asymm_per_token_quant +from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + quantize_affine, + dequantize_affine, +) + +from torchao.quantization.linear_activation_quantized_tensor import ( + LinearActivationQuantizedTensor, + to_linear_activation_quantized, +) + + +class AWQObserverQDQ(AffineQuantizedObserverBase): + """ + AWQ Observer with QDQLayout support and dynamic activation quantization. + + This observer extends the base AWQ implementation to support: + - QDQLayout for ExecuTorch compatibility + - 8-bit dynamic activation quantization for scales + - Improved calibration workflow + """ + + def __init__( + self, + weight: torch.Tensor, + bias: torch.Tensor, + quantization_granularity: PerGroup, + mapping_type: MappingType, + target_dtype: torch.dtype, + n_validation_examples: int, + validation_sequence_len: int, + scale_search_space_size: int = 20, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: Optional[bool] = True, + zero_point_domain=ZeroPointDomain.INT, + ): + """ + Args: + weight: The weight tensor to be observed + bias: The bias tensor to be observed + quantization_granularity: Granularity for weight quantization + mapping_type: Quantization mapping type + target_dtype: Target dtype for quantized weights + n_validation_examples: Number of calibration examples + validation_sequence_len: Sequence length for calibration + scale_search_space_size: Number of scale options to search + quant_min: Minimum quantization value + quant_max: Maximum quantization value + eps: Minimum scale value + scale_dtype: Scale tensor dtype + zero_point_dtype: Zero point tensor dtype + preserve_zero: Whether to preserve zero exactly + zero_point_domain: Domain for zero point values + """ + super().__init__( + mapping_type, + target_dtype, + quantization_granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + self.quantization_granularity = quantization_granularity + self.weight = weight + self.bias = bias + self.n_validation_examples = n_validation_examples + self.validation_sequence_len = validation_sequence_len + self.scale_search_space_size = scale_search_space_size + + # Calibration state + self.calibration_token_count = 0 + self.inputs = 
[] + self.outputs = [] + self.device = self.weight.device + self.average = torch.zeros((1, weight.shape[1]), device=self.device) + + if self.bias is not None: + self.bias = self.bias.to(self.device) + + @torch.no_grad() + def forward(self, input: torch.Tensor, output: torch.Tensor): + """Collect calibration data during forward pass.""" + if len(self.inputs) < self.n_validation_examples: + self.inputs.append(input.to("cpu")) + self.outputs.append(output.to("cpu")) + + # Handle different input shapes + if len(input.shape) == 2: # [batch, hidden] + self.calibration_token_count += input.shape[0] + self.average += input.abs().sum(dim=0) + else: # [batch, seq_len, hidden] + self.calibration_token_count += ( + input.shape[-2] * input.shape[0] + ) # batch * seq_len + # Sum over batch and sequence dimensions to get [hidden_size] + self.average += input.abs().sum(dim=(0, -2)) + + def calculate_qparams(self): + """ + Calculate optimal quantization parameters using AWQ with optional dynamic activation quantization. + + Returns: + Optimal activation scales for AWQ quantization + """ + assert self.outputs, "Must calibrate observer first by running model on data" + + # Normalize average activation magnitudes + self.average /= self.calibration_token_count + + # Move calibration data to device + for i in range(self.n_validation_examples): + self.inputs[i] = self.inputs[i].to(self.device) + self.outputs[i] = self.outputs[i].to(self.device) + + best_loss = float("inf") + best_scales = None + + # Search over scale options + for i in range(self.scale_search_space_size): + ratio = i * 1.0 / self.scale_search_space_size + scales = self.average.pow(ratio).to(self.weight.dtype) + scales = scales / (scales.max() * scales.min()).sqrt() + + # Create quantized weight with QDQLayout + layout = QDQLayout() + tensor_dtype = torch.int8 + + quantized_weight = to_affine_quantized_intx( + self.weight * scales, + self.mapping_type, + (1, self.quantization_granularity.group_size), + tensor_dtype, + quant_min=self.quant_min, + quant_max=self.quant_max, + eps=self.eps, + scale_dtype=self.scale_dtype, + zero_point_dtype=self.zero_point_dtype, + preserve_zero=self.preserve_zero, + zero_point_domain=self.zero_point_domain, + _layout=layout, + ) + + # Evaluate quantization loss + total_loss = 0 + for j in range(self.n_validation_examples): + quantized_output = F.linear( + self.inputs[j] / scales, quantized_weight, self.bias + ) + loss = (self.outputs[j] - quantized_output).pow(2).mean().item() + total_loss += loss + + # Update best scales if this is better + if total_loss < best_loss: + best_scales = scales + best_loss = total_loss + + # Move calibration data back to CPU to save memory + for i in range(self.n_validation_examples): + self.inputs[i] = self.inputs[i].to("cpu") + self.outputs[i] = self.outputs[i].to("cpu") + + return best_scales.detach() + + +class AWQObservedLinearQDQ(torch.nn.Linear): + """ + AWQ Observed Linear layer with QDQLayout support. + + This layer captures activations during calibration and applies AWQ + quantization with QDQLayout for ExecuTorch compatibility. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + act_obs: AWQObserverQDQ, + bias: bool = True, + device=None, + dtype=None, + ): + super().__init__(in_features, out_features, bias, device, dtype) + self.act_obs = act_obs + + def forward(self, input: torch.Tensor): + """Forward pass with activation observation.""" + output = F.linear(input, self.weight, self.bias) + self.act_obs(input, output) + return output + + @classmethod + def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserverQDQ): + """Create AWQObservedLinearQDQ from a regular Linear layer.""" + observed_linear = cls( + float_linear.in_features, + float_linear.out_features, + act_obs, + float_linear.bias is not None, + device=float_linear.weight.device, + dtype=float_linear.weight.dtype, + ) + observed_linear.weight = float_linear.weight + observed_linear.bias = float_linear.bias + return observed_linear + + +def _awq_int8_dynamic_activation_intx_weight_quant(input_tensor, awq_scale): + """ + AWQ-aware activation quantization function that applies AWQ scaling before standard int8 quantization. + + Args: + input_tensor: Input tensor to quantize + awq_scale: AWQ scaling factor to apply before quantization + + Returns: + Quantized input tensor + """ + # Step 1: Apply AWQ scaling + scaled_input = input_tensor / awq_scale + # Step 2: Apply standard int8 dynamic activation quantization + return _int8_asymm_per_token_quant(scaled_input) + + +def insert_awq_observer_qdq_( + model: torch.nn.Module, + n_validation_examples: int, + validation_sequence_len: int, + quant_dtype: torch.dtype = torch.int4, + scale_search_space_size: int = 20, + group_size: int = 128, +): + """ + Insert AWQ observers with QDQLayout support into Linear layers. + + Args: + model: Model to modify in-place + n_validation_examples: Number of calibration examples + validation_sequence_len: Sequence length for calibration + quant_dtype: Target quantization dtype + scale_search_space_size: Number of scale options to search + group_size: Group size for quantization granularity + """ + _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) + + assert quant_dtype != torch.uint8, ( + "Invalid quant_dtype. Please use torch.int1 .. torch.int8" + ) + + # Quantization configuration + mapping_type = MappingType.ASYMMETRIC + quantization_granularity = PerGroup(group_size) + quant_min = 0 + quant_max = ( + 255 if quant_dtype == torch.int8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1 + ) + eps = torch.finfo(torch.float32).eps + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + + def replace_with_observer(layer): + """Replace Linear layer with AWQObservedLinearQDQ.""" + observer = AWQObserverQDQ( + layer.weight, + layer.bias, + quantization_granularity, + mapping_type, + quant_dtype, + n_validation_examples, + validation_sequence_len, + scale_search_space_size, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + zero_point_dtype=zero_point_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + ) + return AWQObservedLinearQDQ.from_float(layer, observer) + + _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) + + +def _is_awq_observed_linear_qdq(mod, *args): + """Filter function to identify AWQObservedLinearQDQ modules for quantization.""" + return isinstance(mod, AWQObservedLinearQDQ) + + +@dataclass +class AWQQDQConfig(AOBaseConfig): + """ + Configuration for AWQ quantization with QDQLayout support. 
+
+    Args:
+        quant_dtype: Target quantization dtype
+        group_size: Group size for quantization granularity
+        set_inductor_config: Whether to set recommended inductor settings
+    """
+
+    quant_dtype: torch.dtype = torch.int4
+    group_size: int = 64
+    set_inductor_config: bool = True
+
+
+@register_quantize_module_handler(AWQQDQConfig)
+def _awq_qdq_transform(
+    module: torch.nn.Module,
+    config: AWQQDQConfig,
+) -> torch.nn.Module:
+    """
+    Transform AWQObservedLinearQDQ to a quantized Linear with QDQLayout.
+
+    Args:
+        module: AWQObservedLinearQDQ module to transform
+        config: AWQ QDQ configuration
+
+    Returns:
+        Quantized Linear module with QDQLayout weights
+    """
+    # Only transform AWQObservedLinearQDQ modules
+    if not isinstance(module, AWQObservedLinearQDQ):
+        return module
+
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    observed_linear = module
+    quant_dtype = config.quant_dtype
+    group_size = config.group_size
+
+    assert quant_dtype != torch.uint8, (
+        "Invalid quant_dtype. Please use torch.int1 .. torch.int8"
+    )
+
+    # Get optimal activation scales from AWQ calibration
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+
+    # Configure quantization parameters to match Int8DynamicActivationIntxWeightConfig
+    mapping_type = MappingType.SYMMETRIC
+    block_size = (1, group_size)
+
+    # Use an int8 target_dtype with dtype-specific bounds for weight quantization to
+    # match QDQLayout expectations
+    target_dtype = torch.int8
+    quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
+    quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
+    eps = torch.finfo(torch.float32).eps
+    preserve_zero = True
+    zero_point_dtype = torch.int64  # Standard for int8
+    zero_point_domain = ZeroPointDomain.INT
+    scale_dtype = torch.float32
+    layout = QDQLayout()
+
+    # Apply the AWQ equalization scale to the weight BEFORE quantization; the matching
+    # inverse rescaling of activations is handled by the metadata wrapper below.
+    awq_scaled_weight = observed_linear.weight * equalization_scale
+
+    # Create quantized weight with QDQLayout using the AWQ-scaled weight
+    qw = to_affine_quantized_intx(
+        awq_scaled_weight,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype=scale_dtype,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+        _layout=layout,
+    )
+
+    activation_quant_func = _int8_asymm_per_token_quant
+    # First, wrap the affine-quantized weight in LinearActivationQuantizedTensor so the
+    # activation is dynamically quantized to int8 before it reaches the quantized weight.
+    qw = to_linear_activation_quantized(qw, activation_quant_func)
+    # Then wrap it in WeightTensorWithLinearActivationScaleMetadata so the activation is
+    # rescaled to compensate for the AWQ weight scaling before it is fed to
+    # LinearActivationQuantizedTensor.
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear
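
The example prints that the model is "ready for ExecuTorch export" and the graph-pattern test stops at `torch.export.export`, but the diff does not show the actual lowering step. Below is a minimal, illustrative sketch of that last mile, assuming a working ExecuTorch install; `to_edge_transform_and_lower` and `XnnpackPartitioner` are standard ExecuTorch APIs, but their module paths and signatures vary between releases, so treat this as a hint rather than part of the PR. `quantized_model` refers to the module returned by `main()` in `examples/awq_qdq_usage_example.py`.

```python
# Illustrative sketch (not part of this PR): lowering the AWQ + QDQLayout model to ExecuTorch.
import torch

from executorch.exir import to_edge_transform_and_lower
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

example_inputs = (torch.randn(1, 64, 512),)

# Export with strict=True, the same call test_graph_pattern_for_executorch uses.
exported = torch.export.export(quantized_model, example_inputs, strict=True)

# Lower to an ExecuTorch program; the XNNPACK partitioner is expected to pick up the
# quantize/dequantize + linear pattern that the FileCheck assertions verify.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()

# Serialize the .pte file for the ExecuTorch runtime.
with open("awq_qdq_model.pte", "wb") as f:
    f.write(et_program.buffer)
```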