
Commit 3277221

Enable range learning for QAT
**Summary:** This commit adds the option for QAT users to use range learning during training. Range learning means we train the scale and zero point instead of recomputing them based on the input at every iteration. Example usage:

```
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
m = M()
example_inputs = (torch.randn(16, 32),)
quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))

# New required step to turn scales and zero points into trainable
# `nn.Parameters`, must be called before initializing the optimizer
initialize_fake_quantizers(m, example_inputs)

# initialize the optimizer
# do training
```

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_dynamic_and_range_learning
python test/quantization/test_qat.py -k test_fake_quantizer_range_learning
python test/quantization/test_qat.py -k test_qat_range_learning
1 parent 31f119e commit 3277221
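For intuition, "train the scale and zero point instead of recomputing them" boils down to treating the qparams as `nn.Parameter`s and routing gradients through the non-differentiable round with a straight-through estimator (STE). The standalone sketch below is illustrative only, not the torchao implementation; the `_STERound` and `fake_quantize` names are made up for the example.

```python
import torch


class _STERound(torch.autograd.Function):
    """Round in the forward pass; pass gradients through unchanged (STE)."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, gy):
        return gy


def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Quantize with a straight-through round, clamp to the integer range,
    # then dequantize. Because the round is an STE, gradients reach the
    # input as well as the learnable qparams.
    q = _STERound.apply(x / scale) + zero_point
    q = torch.clamp(q, qmin, qmax)
    return (q - zero_point) * scale


# The qparams are ordinary trainable parameters, as in range learning.
scale = torch.nn.Parameter(torch.full((4, 1), 0.1))
zero_point = torch.nn.Parameter(torch.zeros(4, 1))
x = torch.randn(4, 16)

out = fake_quantize(x, scale, zero_point)
out.sum().backward()
print(scale.grad.shape, zero_point.grad.shape)  # torch.Size([4, 1]) for both
```

In the actual commit, the STE lives in `_Round` (added to `torchao/quantization/qat/utils.py` below) and the parameter wrapping happens in `FakeQuantizer._maybe_update_qparams_for_range_learning`.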

File tree

- test/quantization/test_qat.py
- torchao/quantization/qat/__init__.py
- torchao/quantization/qat/api.py
- torchao/quantization/qat/fake_quantizer.py
- torchao/quantization/qat/linear.py
- torchao/quantization/qat/utils.py

6 files changed (+212, -17 lines)

test/quantization/test_qat.py (+117)

```diff
@@ -9,6 +9,7 @@
 
 import copy
 import unittest
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -25,7 +26,9 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -95,6 +98,16 @@ def __init__(self):
     def example_inputs(self):
         return (torch.randn(1, 512).to(torch.float),)
 
+    def _get_all_weight_qparams(self) -> List[torch.Tensor]:
+        return [
+            self.linear1.weight_fake_quantizer.scale,
+            self.linear1.weight_fake_quantizer.zero_point,
+            self.sub.linear.weight_fake_quantizer.scale,
+            self.sub.linear.weight_fake_quantizer.zero_point,
+            self.linear2.weight_fake_quantizer.scale,
+            self.linear2.weight_fake_quantizer.zero_point,
+        ]
+
     def forward(self, x):
         x = self.linear1(x)
         x = self.sub(x)
@@ -980,6 +993,21 @@ def test_fake_quantize_config_dtype(self):
         FakeQuantizeConfig(TorchAODType.INT7, "per_token")
         FakeQuantizeConfig(torch.int8, "per_token")
 
+    def test_fake_quantize_config_dynamic_and_range_learning(self):
+        """
+        Test that `is_dynamic` and `range_learning` cannot both be set.
+        """
+        FakeQuantizeConfig(
+            torch.int8, "per_token", is_dynamic=True, range_learning=False
+        )
+        FakeQuantizeConfig(
+            torch.int8, "per_token", is_dynamic=False, range_learning=True
+        )
+        with self.assertRaisesRegex(ValueError, "not compatible"):
+            FakeQuantizeConfig(
+                torch.int8, "per_token", is_dynamic=True, range_learning=True
+            )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
@@ -1486,6 +1514,95 @@ def test_qat_8da4w_prepare_vs_convert(self):
         )
         self.assertEqual(len(non_inf_sqnr), 0, fail_message)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantizer_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        example_inputs = (torch.randn(2, 3),)
+
+        # Not initialized, should fail
+        self.assertFalse(fake_quantizer._initialized)
+        self.assertIsNone(fake_quantizer.scale)
+        self.assertIsNone(fake_quantizer.zero_point)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            fake_quantizer(*example_inputs)
+
+        # Should pass after initializing
+        initialize_fake_quantizers(fake_quantizer, example_inputs)
+        self.assertTrue(fake_quantizer._initialized)
+        self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
+        self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
+        self.assertTrue(fake_quantizer.scale.requires_grad)
+        self.assertTrue(fake_quantizer.zero_point.requires_grad)
+        fake_quantizer(*example_inputs)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+
+        # Not initialized, should fail
+        for t in m._get_all_weight_qparams():
+            self.assertIsNone(t)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            m(*example_inputs)
+
+        # Should pass after initializing
+        # All scales and zero points should be in `m.parameters()`
+        initialize_fake_quantizers(m, example_inputs)
+        params = set(m.parameters())
+        for t in m._get_all_weight_qparams():
+            self.assertIsInstance(t, torch.nn.Parameter)
+            self.assertTrue(t.requires_grad)
+            self.assertTrue(t in params)
+        m(*example_inputs)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float()
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 
 if __name__ == "__main__":
     unittest.main()
```

torchao/quantization/qat/__init__.py (+5, -3)

```diff
@@ -4,6 +4,7 @@
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from .embedding import (
@@ -17,11 +18,12 @@
 __all__ = [
     "ComposableQATQuantizer",
     "FakeQuantizeConfig",
-    "Int4WeightOnlyQATQuantizer",
+    "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",
+    "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "IntXQuantizationAwareTrainingConfig",
+    "initialize_fake_quantizers",
     "intx_quantization_aware_training",
     "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "IntXQuantizationAwareTrainingConfig",
 ]
```

torchao/quantization/qat/api.py (+27, -2)

```diff
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 
@@ -51,7 +51,8 @@ class FakeQuantizeConfig:
         zero_point_precision: zero point dtype (default torch.int32)
         zero_point_domain: whether zero point is in integer (default) or float domain
         is_dynamic: whether to use dynamic (default) or static scale and zero points
-        range_learning: whether to learn scale and zero points during training (coming soon)
+        range_learning: whether to learn scale and zero points during training,
+            not compatible with `is_dynamic`.
 
     kwargs (optional):
         group_size: size of each group in per group fake quantization,
@@ -120,6 +121,10 @@ def __init__(
                 "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)
             )
 
+        # Dynamic is not compatible with range learning
+        if is_dynamic and range_learning:
+            raise ValueError("`is_dynamic` is not compatible with `range_learning`")
+
     def _get_granularity(
         self,
         granularity: Union[Granularity, str, None],
@@ -391,3 +396,23 @@ def convert(
         for quantizer in self.quantizers:
             model = quantizer.convert(model)
         return model
+
+
+def initialize_fake_quantizers(
+    model: torch.nn.Module,
+    example_inputs: Tuple[Any, ...],
+) -> None:
+    """
+    Initialize the scales and zero points on all
+    :class:`~`torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    in the model based on the provided example inputs.
+    """
+    # avoid circular dependencies
+    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+
+    def _set_initialized(m: torch.nn.Module):
+        if isinstance(m, FakeQuantizer):
+            m._initialized = True
+
+    model.apply(_set_initialized)
+    model(*example_inputs)
```
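As the diff shows, `initialize_fake_quantizers` only flags each `FakeQuantizer` as initialized and then runs one forward pass with the example inputs; that forward pass is what actually computes the initial scales and zero points and promotes them to trainable `nn.Parameter`s. A minimal usage sketch, using a stand-in `nn.Sequential` in place of the `M()` module from the commit message:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

# Same range-learning config as in the commit message.
config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32))  # stand-in model
example_inputs = (torch.randn(16, 32),)
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=config))

# Flag every FakeQuantizer and run one forward pass so the initial scales
# and zero points are computed and wrapped as trainable nn.Parameters.
initialize_fake_quantizers(model, example_inputs)

# Only now build the optimizer, so it picks up the learned qparams.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
```

The important ordering is that the optimizer is constructed only after this call, so the learned qparams are registered in `model.parameters()` and therefore updated during training.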

torchao/quantization/qat/fake_quantizer.py (+41, -6)

```diff
@@ -29,6 +29,7 @@
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Round,
 )
 
 
@@ -44,26 +45,38 @@ def __init__(self, config: FakeQuantizeConfig):
         self.scale: Optional[torch.Tensor] = None
         self.zero_point: Optional[torch.Tensor] = None
 
-        # TODO: support range learinng
-        if self.config.range_learning:
-            raise NotImplementedError("Range learning is not supported yet")
+        # For range learning only
+        # TODO: make this configurable?
+        self._scale_eps = 1e-9
+        self._initialized = False
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Apply fake quantization to the tensor based on the bit-width,
         granularity, symmetry, and other properties specified in the config.
         """
         if not self.enabled:
             return x
 
+        if (
+            self.config.range_learning
+            and not self._initialized
+            and (self.scale is None or self.zero_point is None)
+        ):
+            raise ValueError(
+                "Scales and zero points must be initialized for range learning. "
+                "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+                "before initializing the optimizer and beginning training."
+            )
+
         if isinstance(self.config.granularity, PerToken):
             return self._per_token_forward(x)
         elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
             return self._per_channel_or_group_forward(x)
         else:
             raise ValueError("Unknown granularity '%s'" % self.config.granularity)
 
-    def _per_token_forward(self, x: torch.Tensor):
+    def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per token fake quantization on the tensor.
         """
@@ -75,10 +88,11 @@ def _per_token_forward(self, x: torch.Tensor):
                 self.config.scale_precision,
                 self.config.zero_point_precision,
             )
+        self._maybe_update_qparams_for_range_learning()
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
 
-    def _per_channel_or_group_forward(self, x: torch.Tensor):
+    def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per channel or per group fake quantization on the tensor.
         We express per channel using per group where the group size is the size
@@ -117,6 +131,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
                 scale_precision,
             )
             self.zero_point = self.zero_point.to(zero_point_precision)
+        self._maybe_update_qparams_for_range_learning()
 
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         return _fake_quantize_per_channel_group(
@@ -135,6 +150,26 @@ def _should_compute_qparams(self) -> bool:
         """
         return self.config.is_dynamic or self.scale is None or self.zero_point is None
 
+    def _maybe_update_qparams_for_range_learning(self) -> None:
+        """
+        If range learning is enabled, turn scales and zero points into trainable parameters.
+        This function is idempotent and should only be called once.
+        """
+        if (
+            not self.config.range_learning
+            or isinstance(self.scale, torch.nn.Parameter)
+            or isinstance(self.zero_point, torch.nn.Parameter)
+        ):
+            return
+        scale, zero_point = self.scale, self.zero_point
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
+        # Stabilize range learning
+        scale = torch.clamp(scale, min=self._scale_eps)
+        zero_point = _Round.apply(zero_point)
+        zero_point = torch.clamp(zero_point, qmin, qmax)
+        self.scale = torch.nn.Parameter(scale, requires_grad=True)
+        self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
+
     def __repr__(self) -> str:
         """
         Return a human readable representation of this `FakeQuantizer` with config details.
```

torchao/quantization/qat/linear.py (+8, -6)

```diff
@@ -18,6 +18,7 @@
     _replace_linear_int4,
     groupwise_affine_quantize_tensor,
 )
+from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_primitives import (
     TorchAODType,
     ZeroPointDomain,
@@ -83,12 +84,13 @@ def __init__(
 
         # initialize weight fake quantizer
         if weight_config is not None:
-            group_size = weight_config.group_size
-            if group_size is not None and in_features % group_size != 0:
-                raise ValueError(
-                    "in_features (%s) %% group_size (%s) must be == 0"
-                    % (in_features, group_size)
-                )
+            if isinstance(weight_config.granularity, PerGroup):
+                group_size = weight_config.group_size
+                if group_size is not None and in_features % group_size != 0:
+                    raise ValueError(
+                        "in_features (%s) %% group_size (%s) must be == 0"
+                        % (in_features, group_size)
+                    )
             self.weight_fake_quantizer = FakeQuantizer(weight_config)
         else:
             self.weight_fake_quantizer = None
```

torchao/quantization/qat/utils.py (+14)

```diff
@@ -91,6 +91,20 @@ def backward(ctx, gy):
         return (gy,)
 
 
+class _Round(torch.autograd.Function):
+    """
+    Implementation of generic round operation with backward STE.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+        return torch.round(x)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy
+
+
 def _fake_quantize_per_channel_group(
     input: torch.Tensor,
     scales: torch.Tensor,
```
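`_Round` is a standard straight-through estimator: it rounds in the forward pass and passes gradients through unchanged in the backward pass, which is what keeps the rounded zero point trainable. A quick sanity check, assuming you import the private helper directly from `torchao.quantization.qat.utils` as defined in the diff above:

```python
import torch
from torchao.quantization.qat.utils import _Round

x = torch.tensor([0.2, 0.7, 1.5], requires_grad=True)
y = _Round.apply(x)   # forward rounds the values: 0., 1., 2.

y.sum().backward()
print(y)              # rounded values
print(x.grad)         # tensor([1., 1., 1.]) -- gradients pass straight through
```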
