42 commits
48cdb61  Int8Tensor migration (jcaip, Dec 1, 2025)
0b73aed  ruff fixes (jcaip, Dec 1, 2025)
1e49945  add init (jcaip, Dec 1, 2025)
669b6ee  fix ruff again (jcaip, Dec 1, 2025)
9071526  update (jcaip, Dec 1, 2025)
1539e0f  wip (jcaip, Dec 2, 2025)
d9a2b1b  Merge branch 'main' into jcaip/int8-tensor (jcaip, Dec 3, 2025)
673f228  undo update tests (jcaip, Dec 3, 2025)
739fd64  fix ruff (jcaip, Dec 3, 2025)
750db1a  fix varname (jcaip, Dec 3, 2025)
9410488  fix typing (jcaip, Dec 3, 2025)
45a3a76  add tests (jcaip, Dec 3, 2025)
4e2f09c  fix dtype (jcaip, Dec 3, 2025)
dd80cca  fix ci (jcaip, Dec 3, 2025)
7f73062  address granularity cr (jcaip, Dec 4, 2025)
ac6a2b6  update _choose_quant_func_and_quantize_tensor (jcaip, Dec 4, 2025)
f28df4a  make block size required attribute (jcaip, Dec 4, 2025)
328585e  made dtype required as well (jcaip, Dec 4, 2025)
ce4d568  address nits (jcaip, Dec 4, 2025)
a665d45  skip per tensor weight only test for now (jcaip, Dec 4, 2025)
0338016  add static quant (jcaip, Dec 3, 2025)
ee39691  add static quant (jcaip, Dec 4, 2025)
9eb0aa9  update (jcaip, Dec 5, 2025)
d4a1514  static quant working eager + compile (jcaip, Dec 6, 2025)
3cdea56  remove file (jcaip, Dec 6, 2025)
fa9022d  added asserts (jcaip, Dec 6, 2025)
8ce5cde  undo smoothquant change (jcaip, Dec 6, 2025)
6f64121  fix return (jcaip, Dec 6, 2025)
8ae921d  Merge branch 'main' into jcaip/static-quant-rebased (jcaip, Dec 7, 2025)
5b9e243  got smoothquant + int8 static working (jcaip, Dec 8, 2025)
7a0e38f  generalized smoothquat code (jcaip, Dec 8, 2025)
3d18edf  free tests (jcaip, Dec 8, 2025)
9e07f8b  fix static scale check (jcaip, Dec 8, 2025)
4274e02  update (jcaip, Dec 8, 2025)
b5309eb  address cr feedback (jcaip, Dec 9, 2025)
a732fee  Merge branch 'jcaip/static-quant-rebased' into jcaip/enable-smoothquant (jcaip, Dec 9, 2025)
0c23589  Merge branch 'main' into jcaip/enable-smoothquant (jcaip, Dec 9, 2025)
0872986  update (jcaip, Dec 17, 2025)
049830f  fix ruff (jcaip, Dec 17, 2025)
2586ab6  fix varname (jcaip, Dec 18, 2025)
8cd656e  cr feedback (jcaip, Dec 19, 2025)
ea9b8e2  fix compare (jcaip, Dec 19, 2025)
34 changes: 21 additions & 13 deletions test/prototype/test_smoothquant.py
@@ -15,12 +15,12 @@
)
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.linear_activation_scale import (
WeightTensorWithLinearActivationScaleMetadata,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
)
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.utils import (
compute_error as SQNR,
)
@@ -83,7 +83,9 @@ def setUpClass(cls):
@common_utils.parametrize(
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8DynamicActivationInt8WeightConfig(version=2),
Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
Int8StaticActivationInt8WeightConfig(granularity=PerTensor()),
# Note: float8_static_activation_float8_weight is broken after recent PyTorch update.
# TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
],
@@ -101,7 +103,15 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):

# Step 1. Basic quantization
basic_model = deepcopy(m)
quantize_(basic_model, base_config)
if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
quantize_(
basic_model,
Int8DynamicActivationInt8WeightConfig(
version=2, granularity=base_config.granularity
),
)
else:
quantize_(basic_model, base_config)
out_basic = basic_model(*x)
loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()

@@ -119,12 +129,10 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):

config.step = SmoothQuantStep.CONVERT
quantize_(model, config)
assert isinstance(
model.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
)
assert isinstance(
model.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
)
assert isinstance(model.linear1.weight, SupportsActivationPreScaling)
assert isinstance(model.linear2.weight, SupportsActivationPreScaling)
assert model.linear1.weight.act_pre_scale is not None
assert model.linear2.weight.act_pre_scale is not None

out_smoothquant = model(*x)
loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item()
@@ -138,7 +146,7 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
@common_utils.parametrize(
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8DynamicActivationInt8WeightConfig(version=2),
# TODO: Check more quantization APIs
],
)
@@ -177,7 +185,7 @@ def test_observer_insertion(self, base_config):
@common_utils.parametrize(
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8DynamicActivationInt8WeightConfig(version=2),
# TODO: Check more quantization APIs
],
)
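For context on what these tests exercise: the SmoothQuant prototype runs in two steps, PREPARE (insert observers and calibrate) and CONVERT (fold the smoothing factor into the weights and quantize). A minimal sketch of that flow with the new static-int8 base config follows; it assumes the prototype's SmoothQuantConfig wrapper with base_config/step/alpha fields, which is used but not defined in this diff, so treat the constructor arguments as illustrative.

import torch

from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(512, 512)).cuda().bfloat16()
calibration_batches = [
    torch.randn(8, 512, device="cuda", dtype=torch.bfloat16) for _ in range(4)
]

# Base config: the activation scale is left unset here and is filled in by
# SmoothQuant at CONVERT time from the observed calibration data.
base_config = Int8StaticActivationInt8WeightConfig(granularity=PerRow())

# Step 1: PREPARE inserts observed linear modules and records activations.
config = SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.PREPARE, alpha=0.5)
quantize_(model, config)
for batch in calibration_batches:
    model(batch)

# Step 2: CONVERT computes the smoothing factor, quantizes the smoothed weight,
# and stores 1 / smoothing_factor as act_pre_scale on the quantized weight tensor.
config.step = SmoothQuantStep.CONVERT
quantize_(model, config)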
45 changes: 39 additions & 6 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -274,17 +274,50 @@ def test_static_activation_per_row_int8_weight(self, granularity, dtype):
static_out_compile = model_dynamic_quant(input_tensor)
sqnr_static_compile = compute_error(model_out_baseline, static_out_compile)

assert (
sqnr_static_compile
== sqnr_static_eager
== sqnr_dynamic_compile
== sqnr_dynamic_eager
), "SQNR should be the same for all quantization methods and eager/compile"
assert sqnr_static_compile == sqnr_static_eager, (
f"Static SQNR mismatch: compile={sqnr_static_compile} vs eager={sqnr_static_eager}"
)
assert sqnr_static_eager == sqnr_dynamic_compile, (
f"Static eager vs dynamic compile SQNR mismatch: {sqnr_static_eager} vs {sqnr_dynamic_compile}"
)
assert sqnr_dynamic_compile == sqnr_dynamic_eager, (
f"Dynamic SQNR mismatch: compile={sqnr_dynamic_compile} vs eager={sqnr_dynamic_eager}"
)

# eager numerics should match exactly
# for compile, we can't compare dynamic vs static because we may get slightly different qparams when fused
torch.testing.assert_close(dynamic_out_eager, static_out_eager)

def test_static_per_feature_act_quant_not_supported(self):
"""Test that PerRow(dim != -1) activation quantization raises an error.

Per-feature activation quantization (PerRow(dim=0)) would require slicing
act_scale when weight is sliced, which is not currently supported.
We explicitly disallow this configuration.
"""
from torchao.quantization.granularity import PerRow as PerRowGranularity

# Attempting to create a config with PerRow(dim=0) should raise an error
with self.assertRaises(ValueError) as cm:
static_config = Int8StaticActivationInt8WeightConfig(
static_scale=torch.ones(1, 1, device="cuda"),
granularity=PerRowGranularity(dim=0), # This should fail
act_mapping_type=MappingType.SYMMETRIC,
)

self.assertIn("PerRow(dim=-1)", str(cm.exception))
self.assertIn("Per-feature", str(cm.exception).lower())

# Verify that PerRow() (default dim=-1) and PerTensor() still work
for granularity in [PerRow(), PerTensor()]:
static_config = Int8StaticActivationInt8WeightConfig(
static_scale=torch.ones(1, 1, device="cuda"),
granularity=granularity,
act_mapping_type=MappingType.SYMMETRIC,
)
# Should not raise an error
self.assertIsNotNone(static_config)


if __name__ == "__main__":
common_utils.run_tests()
32 changes: 25 additions & 7 deletions torchao/prototype/smoothquant/api.py
@@ -10,13 +10,15 @@
import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.linear_activation_scale import (
to_weight_tensor_with_linear_activation_scale_metadata,
)
from torchao.quantization.quant_api import (
_QUANTIZE_CONFIG_HANDLER,
Int8StaticActivationInt8WeightConfig,
_linear_extra_repr,
)
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
@@ -95,8 +97,18 @@ def _smooth_quant_transform(
else:
raise ValueError(f"Unexpected step: {step}")

if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
Review comment from jerryzh168 (Contributor), Dec 18, 2025: I think we shouldn't have specific config here, maybe change this to a similar protocol like SupportsActivationPreScaling for config?

Reply from jcaip (Author): I think figuring out how to do this generally will need a bit more design, we'd need to figure out how to map to the appropriate QuantizeTensorToInt/FloatXKwargs object. Agree we should be able to do this though, but can I address in a later PR?

quant_kwargs = QuantizeTensorToInt8Kwargs(
granularity=base_config.granularity,
mapping_type=base_config.act_mapping_type,
)
else:
quant_kwargs = None

# Compute smoothed weight parameters
smoothing_factor = observed_linear.obs.calculate_qparams()
smoothing_factor, activation_scale = observed_linear.obs.calculate_qparams(
weight_quant_kwargs=quant_kwargs
)
weight = observed_linear.weight * smoothing_factor

# Create new linear layer
@@ -111,15 +123,21 @@
linear.bias = observed_linear.bias

# Quantize weights
if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
base_config.static_scale = activation_scale

base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
dummy_mod = DummyModule(weight)
quant_mod = base_config_handler(dummy_mod, base_config)
qw = quant_mod.weight

# Add smoothing factor metadata
qw = to_weight_tensor_with_linear_activation_scale_metadata(
qw, smoothing_factor.to(qw.dtype)
# Add smoothing factor as activation pre-scale
assert isinstance(qw, SupportsActivationPreScaling), (
"weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# Store reciprocal for runtime efficiency: act * act_pre_scale
qw.act_pre_scale = 1.0 / smoothing_factor

linear.weight = torch.nn.Parameter(qw, requires_grad=False)
linear.extra_repr = types.MethodType(_linear_extra_repr, linear)

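The review thread above suggests replacing the isinstance check on Int8StaticActivationInt8WeightConfig with a config-level protocol, similar to SupportsActivationPreScaling on tensors. A hypothetical sketch of that direction follows; the protocol name and method are illustrative only and are not part of this PR or of torchao.

from typing import Optional, Protocol, runtime_checkable

from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    QuantizeTensorToInt8Kwargs,
)


@runtime_checkable
class SupportsStaticActivationScale(Protocol):
    """Hypothetical protocol for base configs whose activation quantization
    parameters can be precomputed from calibration data (e.g. by SmoothQuant)."""

    def activation_quant_kwargs(self) -> Optional[QuantizeTensorToInt8Kwargs]:
        """Return the kwargs needed to quantize calibration activations.

        A fully general version would return the appropriate QuantizeTensorTo*Kwargs
        type for the config, per the discussion in the review thread.
        """
        ...


# _smooth_quant_transform could then replace the concrete-config branch with:
#
#     if isinstance(base_config, SupportsStaticActivationScale):
#         quant_kwargs = base_config.activation_quant_kwargs()
#     else:
#         quant_kwargs = None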
25 changes: 19 additions & 6 deletions torchao/prototype/smoothquant/core.py
@@ -9,6 +9,10 @@
import torch
import torch.nn.functional as F

from torchao.quantization.quantize_.common import (
_choose_quant_func_and_quantize_tensor,
)


class SmoothQuantStep(str, Enum):
PREPARE = "prepare"
@@ -41,13 +45,14 @@ def forward(self, input: torch.Tensor):
self.inputs.append(input.to("cpu"))
return input

def calculate_qparams(self):
def calculate_qparams(self, weight_quant_kwargs=None):
assert self.inputs and len(self.inputs) > 0, (
"calibrate observer first by running model on exemplar data"
)
inputs = [inp.to(self.device) for inp in self.inputs]
acc = torch.cat(inputs, dim=0)
# Reshape if needed: [batch, seq, features] -> [batch*seq, features]
example_input_for_quantization = acc
if acc.ndim > 2:
acc = acc.view(-1, acc.shape[-1])

@@ -57,12 +62,20 @@ def calculate_qparams(self):

# Calculate smoothing factor
if self.alpha is None:
return torch.ones_like(x_abs_max)
smoothing_factor = torch.ones_like(x_abs_max)
else:
eps = torch.finfo(torch.float32).eps
smoothing_factor = torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
w_abs_max + eps, 1 - self.alpha
)

eps = torch.finfo(torch.float32).eps
return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
w_abs_max + eps, 1 - self.alpha
)
if weight_quant_kwargs is not None:
quant_smooth_activation = _choose_quant_func_and_quantize_tensor(
example_input_for_quantization / smoothing_factor, weight_quant_kwargs
)
return smoothing_factor, quant_smooth_activation.scale
else:
return smoothing_factor, None


class SmoothQuantObservedLinear(torch.nn.Linear):
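For reference, the smoothing factor computed in calculate_qparams above is the standard SmoothQuant form, taken per input channel j, with the activation and weight maxima computed in the elided portion of the observer:

    s_j = \frac{(\max_i |X_{ij}| + \epsilon)^{\alpha}}{(\max_k |W_{kj}| + \epsilon)^{1-\alpha}}

The weight is then smoothed as W * s (each input channel scaled up), while the activation is divided by s at runtime; api.py stores that reciprocal as act_pre_scale = 1 / s. When weight_quant_kwargs is provided, the observer additionally quantizes X / s with those kwargs and returns the resulting scale, which becomes the static activation scale for Int8StaticActivationInt8WeightConfig.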
14 changes: 11 additions & 3 deletions torchao/quantization/quant_api.py
@@ -1651,14 +1651,14 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
Configuration for applying int8 static symmetric quantization to both activation and weight

Args:
scale (torch.Tensor): The scale tensor for activation quantization.
static_scale (torch.Tensor): The scale tensor for activation quantization.
granularity (Granularity): The granularity of quantization. PerRow() and PerTensor() are supported currently
act_mapping_type (MappingType): The mapping type for activation quantization. only SYMMETRIC is supported currently
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
version (int): the version of the config
"""

scale: torch.Tensor
static_scale: Optional[torch.Tensor] = None
granularity: Granularity = PerRow()
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
set_inductor_config: bool = True
@@ -1669,6 +1669,14 @@ def __post_init__(self):
"torchao.quantization.Int8StaticActivationInt8WeightConfig"
)

# Validate activation granularity for static quantization
if isinstance(self.granularity, PerRow) and self.granularity.dim != -1:
raise ValueError(
f"Int8StaticActivationInt8WeightConfig only supports PerRow(dim=-1) "
f"for activation quantization, got PerRow(dim={self.granularity.dim}). "
f"Per-feature activation quantization is not supported due to slicing limitations."
)


@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
def _int8_static_activation_int8_weight_transform(
@@ -1700,7 +1708,7 @@ def _int8_static_activation_int8_weight_transform(
granularity=activation_granularity,
mapping_type=config.act_mapping_type,
),
act_scale=config.scale.detach(),
act_scale=config.static_scale.detach(),
)

setattr(
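A minimal sketch of using this config directly, outside of SmoothQuant, assuming the activation scale has already been produced by some calibration step; the placeholder scale of ones mirrors the unit test rather than a real pipeline.

import torch

from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().bfloat16()

config = Int8StaticActivationInt8WeightConfig(
    # In practice this comes from calibration (e.g. the SmoothQuant observer).
    static_scale=torch.ones(1, 1, device="cuda"),
    # PerRow() (i.e. dim=-1) and PerTensor() are accepted; PerRow(dim=0) now raises ValueError.
    granularity=PerRow(),
)
quantize_(model, config)
out = model(torch.randn(8, 256, device="cuda", dtype=torch.bfloat16))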
20 changes: 17 additions & 3 deletions torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Optional

@@ -60,7 +59,7 @@ class Int8Tensor(TorchAOBaseTensor):
"""

tensor_data_names = ["qdata", "scale"]
optional_tensor_data_names = ["act_scale"]
optional_tensor_data_names = ["act_scale", "act_pre_scale"]
tensor_attribute_names = ["block_size", "dtype"]
optional_tensor_attribute_names = [
"act_quant_kwargs",
@@ -73,6 +72,7 @@ def __new__(
block_size: List[int],
dtype: torch.dtype,
act_scale=None,
act_pre_scale: Optional[torch.Tensor] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
):
kwargs = {
@@ -89,6 +89,7 @@ def __init__(
block_size: List[int],
dtype: torch.dtype,
act_scale=None,
act_pre_scale: Optional[torch.Tensor] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
):
super().__init__()
@@ -98,6 +99,7 @@ def __init__(
# don't set dtype because this gets done in __new__
self.act_quant_kwargs = act_quant_kwargs
self.act_scale = act_scale
self.act_pre_scale = act_pre_scale

def __repr__(self):
return (
@@ -106,6 +108,7 @@ def __repr__(self):
f"qdata={self.qdata}, "
f"scale={self.scale}, "
f"act_scale={self.act_scale}, "
f"act_pre_scale={self.act_scale}, "
f"block_size={self.block_size}, "
f"shape={self.shape}, "
f"device={self.device}, "
@@ -121,6 +124,7 @@ def from_hp(
scale: Optional[torch.Tensor] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
act_scale: Optional[torch.Tensor] = None,
act_pre_scale: Optional[torch.Tensor] = None,
):
"""Create Int8Tensor from high-precision tensor"""
block_size = get_block_size(hp_tensor.shape, granularity)
@@ -161,6 +165,7 @@
block_size,
hp_tensor.dtype,
act_scale=act_scale,
act_pre_scale=act_pre_scale,
act_quant_kwargs=act_quant_kwargs,
)

@@ -198,13 +203,18 @@ def _(func, types, args, kwargs):

output_dtype = activation_tensor.dtype

# Apply activation pre-scaling if present (for AWQ, SmoothQuant, etc.)
if weight_tensor.act_pre_scale is not None:
activation_tensor = activation_tensor * weight_tensor.act_pre_scale

if weight_tensor.act_quant_kwargs is not None:
# for int8 dynamic + static quantization path

activation_tensor = _choose_quant_func_and_quantize_tensor(
activation_tensor,
weight_tensor.act_quant_kwargs,
scale=weight_tensor.act_scale,
)
# Dynamic activation quantization path

# 1. do the matrix form of dot(X_i, W_j)
#
@@ -292,6 +302,8 @@ def _(func, types, args, kwargs):
block_size,
self.dtype,
act_quant_kwargs=self.act_quant_kwargs,
act_scale=self.act_scale,
Review comment (Contributor): I guess slice doesn't work for static quant int8 before, can you add a test for that?
act_pre_scale=self.act_pre_scale,
),
)

@@ -322,6 +334,8 @@ def _(func, types, args, kwargs):
old_int8_tensor.scale[index],
old_int8_tensor.block_size[1:],
old_int8_tensor.dtype,
old_int8_tensor.act_scale,
Review comment (Contributor): same for this one, seems like select op breaks before with static quant
old_int8_tensor.act_pre_scale,
old_int8_tensor.act_quant_kwargs,
)
return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor)
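Putting the pieces together: in the linear dispatch above, act_pre_scale is applied to the activation before (dynamic or static) int8 quantization, so the SmoothQuant smoothing cancels against the smoothed weight. The standalone snippet below illustrates just that cancellation in plain torch; it mimics the math, not the actual int8 kernel path.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)   # activations [batch, in_features]
w = torch.randn(16, 8)  # weight [out_features, in_features]

# Smoothing factor per input channel, alpha = 0.5 (eps omitted for brevity).
s = x.abs().amax(dim=0).pow(0.5) / w.abs().amax(dim=0).pow(0.5)

w_smoothed = w * s        # done once at quantization time (input channels scaled up)
act_pre_scale = 1.0 / s   # stored on the quantized weight tensor

# At runtime the dispatch computes (x * act_pre_scale) @ w_smoothed.T, which in
# exact arithmetic equals the original x @ w.T.
torch.testing.assert_close(x @ w.T, (x * act_pre_scale) @ w_smoothed.T)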