
Commit cb98579

Enable range learning for QAT
**Summary:** This commit adds the option for QAT users to use range learning during training. Range learning means we train the scale and zero point instead of recomputing them based on the input at every iteration. Example usage:

```
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
m = M()
example_inputs = (torch.randn(16, 32),)
quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))

# New required step to turn scales and zero points into trainable
# `nn.Parameters`, must be called before initializing the optimizer
initialize_fake_quantizers(m, example_inputs)

# initialize the optimizer
# do training
```

**Test Plan:**

python test/quantization/test_qat.py -k test_fake_quantize_config_dynamic_and_range_learning
python test/quantization/test_qat.py -k test_fake_quantizer_range_learning
python test/quantization/test_qat.py -k test_qat_range_learning
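The example above stops at `# do training`. A minimal sketch of that step, mirroring the training simulation in the new `test_qat_range_learning` test below; it assumes the `m` and `example_inputs` from the snippet above, and the loss, target shape, and hyperparameters are illustrative only:

```
# Hypothetical continuation of the example usage: once initialized, the scales
# and zero points are nn.Parameters and are updated like any other weight.
optimizer = torch.optim.SGD(
    m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
)
loss_fn = torch.nn.CrossEntropyLoss()
target = torch.randn(16, 64)  # assumed to match m's output shape

for _ in range(10):  # a few illustrative training steps
    out = m(*example_inputs)
    loss = loss_fn(out, target)
    optimizer.zero_grad()
    loss.backward()  # gradients also flow into the learned scales and zero points
    optimizer.step()
```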
1 parent 12467d2 commit cb98579

File tree

8 files changed: +217 additions, -19 deletions


test/quantization/test_qat.py

Lines changed: 117 additions & 0 deletions

@@ -9,6 +9,7 @@
 
 import copy
 import unittest
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +27,9 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -99,6 +102,16 @@ def __init__(self):
     def example_inputs(self):
         return (torch.randn(1, 512).to(torch.float),)
 
+    def _get_all_weight_qparams(self) -> List[torch.Tensor]:
+        return [
+            self.linear1.weight_fake_quantizer.scale,
+            self.linear1.weight_fake_quantizer.zero_point,
+            self.sub.linear.weight_fake_quantizer.scale,
+            self.sub.linear.weight_fake_quantizer.zero_point,
+            self.linear2.weight_fake_quantizer.scale,
+            self.linear2.weight_fake_quantizer.zero_point,
+        ]
+
     def forward(self, x):
         x = self.linear1(x)
         x = self.sub(x)
@@ -996,6 +1009,21 @@ def test_fake_quantize_config_dtype(self):
         FakeQuantizeConfig(TorchAODType.INT7, "per_token")
         FakeQuantizeConfig(torch.int8, "per_token")
 
+    def test_fake_quantize_config_dynamic_and_range_learning(self):
+        """
+        Test that `is_dynamic` and `range_learning` cannot both be set.
+        """
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=True, range_learning=False
+        )
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=False, range_learning=True
+        )
+        with self.assertRaisesRegex(ValueError, "not compatible"):
+            FakeQuantizeConfig(
+                torch.int8, "per_channel", is_dynamic=True, range_learning=True
+            )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
@@ -1513,6 +1541,95 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         )
         self.assertEqual(len(non_inf_sqnr), 0, fail_message)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantizer_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        example_inputs = (torch.randn(2, 3),)
+
+        # Not initialized, should fail
+        self.assertFalse(fake_quantizer._initialized)
+        self.assertIsNone(fake_quantizer.scale)
+        self.assertIsNone(fake_quantizer.zero_point)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            fake_quantizer(*example_inputs)
+
+        # Should pass after initializing
+        initialize_fake_quantizers(fake_quantizer, example_inputs)
+        self.assertTrue(fake_quantizer._initialized)
+        self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
+        self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
+        self.assertTrue(fake_quantizer.scale.requires_grad)
+        self.assertTrue(fake_quantizer.zero_point.requires_grad)
+        fake_quantizer(*example_inputs)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_range_learning(self):
+        """
+        Test end-to-end QAT flow with range learning.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+
+        # Not initialized, should fail
+        for t in m._get_all_weight_qparams():
+            self.assertIsNone(t)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            m(*example_inputs)
+
+        # Should pass after initializing
+        # All scales and zero points should be in `m.parameters()`
+        initialize_fake_quantizers(m, example_inputs)
+        params = set(m.parameters())
+        for t in m._get_all_weight_qparams():
+            self.assertIsInstance(t, torch.nn.Parameter)
+            self.assertTrue(t.requires_grad)
+            self.assertTrue(t in params)
+        m(*example_inputs)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float()
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 
 if __name__ == "__main__":
     unittest.main()

third_party/cutlass

Submodule cutlass updated 507 files

torchao/quantization/qat/__init__.py

Lines changed: 5 additions & 3 deletions

@@ -4,6 +4,7 @@
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from .embedding import (
@@ -17,11 +18,12 @@
 __all__ = [
     "ComposableQATQuantizer",
     "FakeQuantizeConfig",
-    "Int4WeightOnlyQATQuantizer",
+    "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",
+    "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "IntXQuantizationAwareTrainingConfig",
+    "initialize_fake_quantizers",
     "intx_quantization_aware_training",
     "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "IntXQuantizationAwareTrainingConfig",
 ]

torchao/quantization/qat/api.py

Lines changed: 27 additions & 2 deletions

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 
@@ -51,7 +51,8 @@ class FakeQuantizeConfig:
         zero_point_precision: zero point dtype (default torch.int32)
         zero_point_domain: whether zero point is in integer (default) or float domain
         is_dynamic: whether to use dynamic (default) or static scale and zero points
-        range_learning: whether to learn scale and zero points during training (coming soon)
+        range_learning: whether to learn scale and zero points during training
+            (default false), not compatible with `is_dynamic`.
 
     kwargs (optional):
         group_size: size of each group in per group fake quantization,
@@ -120,6 +121,10 @@ def __init__(
                 "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)
             )
 
+        # Dynamic is not compatible with range learning
+        if is_dynamic and range_learning:
+            raise ValueError("`is_dynamic` is not compatible with `range_learning`")
+
     def _get_granularity(
         self,
         granularity: Union[Granularity, str, None],
@@ -391,3 +396,23 @@ def convert(
         for quantizer in self.quantizers:
             model = quantizer.convert(model)
         return model
+
+
+def initialize_fake_quantizers(
+    model: torch.nn.Module,
+    example_inputs: Tuple[Any, ...],
+) -> None:
+    """
+    Initialize the scales and zero points on all
+    :class:`~`torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    in the model based on the provided example inputs.
+    """
+    # avoid circular dependencies
+    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+
+    def _set_initialized(m: torch.nn.Module):
+        if isinstance(m, FakeQuantizer):
+            m._initialized = True
+
+    model.apply(_set_initialized)
+    model(*example_inputs)
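A small sketch of why `initialize_fake_quantizers` must run before the optimizer is constructed: scales and zero points only become `nn.Parameter`s, and therefore only appear in `model.parameters()`, once the example inputs have been pushed through the model. This reuses the `M()` test model defined in `test_qat.py` above; the setup is illustrative, not part of the diff:

```
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
m = M()  # the three-linear test model from test_qat.py above
quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))

n_before = len(list(m.parameters()))  # only the linear weights and biases
initialize_fake_quantizers(m, m.example_inputs())
n_after = len(list(m.parameters()))   # now also one scale and one zero point per layer
assert n_after > n_before

# An optimizer built before the call would silently miss the new qparams:
optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)
```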

torchao/quantization/qat/embedding.py

Lines changed: 2 additions & 0 deletions

@@ -92,6 +92,7 @@ def to_embedding(self) -> torch.nn.Embedding:
             self.scale_grad_by_freq,
             self.sparse,
             device=self.weight.device,
+            dtype=self.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
@@ -116,6 +117,7 @@ def from_embedding(
             mod.sparse,
             weight_config=weight_config,
             device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to

torchao/quantization/qat/fake_quantizer.py

Lines changed: 41 additions & 7 deletions

@@ -31,6 +31,7 @@
 from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Round,
 )
 
 
@@ -46,32 +47,43 @@ def __init__(self, config: FakeQuantizeConfig):
         self.scale: Optional[torch.Tensor] = None
         self.zero_point: Optional[torch.Tensor] = None
 
-        # TODO: support range learinng
-        if self.config.range_learning:
-            raise NotImplementedError("Range learning is not supported yet")
+        # For range learning only
+        # TODO: make this configurable?
+        self._scale_eps = 1e-9
+        self._initialized = False
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Apply fake quantization to the tensor based on the bit-width,
         granularity, symmetry, and other properties specified in the config.
         """
         if not self.enabled:
             return x
 
+        if (
+            self.config.range_learning
+            and not self._initialized
+            and (self.scale is None or self.zero_point is None)
+        ):
+            raise ValueError(
+                "Scales and zero points must be initialized for range learning. "
+                "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+                "before initializing the optimizer and beginning training."
+            )
+
         if isinstance(self.config.granularity, PerToken):
             return self._per_token_forward(x)
         elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
             return self._per_channel_or_group_forward(x)
         else:
             raise ValueError("Unknown granularity '%s'" % self.config.granularity)
 
-    def _per_token_forward(self, x: torch.Tensor):
+    def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per token fake quantization on the tensor.
         """
         if self.config.is_symmetric:
             raise NotImplementedError("Symmetric per token is not supported yet")
-
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         if self._should_compute_qparams():
             self.scale, self.zero_point = choose_qparams_affine(
@@ -84,9 +96,10 @@ def _per_token_forward(self, x: torch.Tensor):
                 scale_dtype=self.config.scale_precision,
                 zero_point_dtype=self.config.zero_point_precision,
             )
+            self._maybe_update_qparams_for_range_learning()
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
 
-    def _per_channel_or_group_forward(self, x: torch.Tensor):
+    def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per channel or per group fake quantization on the tensor.
         We express per channel using per group where the group size is the size
@@ -126,6 +139,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
                 scale_precision,
             )
             self.zero_point = self.zero_point.to(zero_point_precision)
+            self._maybe_update_qparams_for_range_learning()
 
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         return _fake_quantize_per_channel_group(
@@ -144,6 +158,26 @@ def _should_compute_qparams(self) -> bool:
         """
         return self.config.is_dynamic or self.scale is None or self.zero_point is None
 
+    def _maybe_update_qparams_for_range_learning(self) -> None:
+        """
+        If range learning is enabled, turn scales and zero points into trainable parameters.
+        This function is idempotent and should only be called once.
+        """
+        if (
+            not self.config.range_learning
+            or isinstance(self.scale, torch.nn.Parameter)
+            or isinstance(self.zero_point, torch.nn.Parameter)
+        ):
+            return
+        scale, zero_point = self.scale, self.zero_point
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
+        # Stabilize range learning
+        scale = torch.clamp(scale, min=self._scale_eps)
+        zero_point = _Round.apply(zero_point)
+        zero_point = torch.clamp(zero_point, qmin, qmax)
+        self.scale = torch.nn.Parameter(scale, requires_grad=True)
+        self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
+
     def __repr__(self) -> str:
         """
         Return a human readable representation of this `FakeQuantizer` with config details.
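`_Round` is imported from `.utils` but not shown in this diff. The new `_maybe_update_qparams_for_range_learning` uses `_Round.apply` to snap the initial zero point to an integer, and a straight-through-estimator round is the usual way to do that without blocking gradients. The sketch below is an assumption about what such a helper looks like, not the actual torchao implementation:

```
import torch


class _RoundSTE(torch.autograd.Function):
    """Round to the nearest integer in the forward pass, but pass gradients
    through unchanged (straight-through estimator), so a rounded zero point
    can still receive gradient updates."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output


# Usage mirroring `_maybe_update_qparams_for_range_learning` above:
zero_point = torch.tensor([0.4, -1.6, 2.2], requires_grad=True)
rounded = _RoundSTE.apply(zero_point)  # tensor([0., -2., 2.])
rounded.sum().backward()               # gradients pass straight through
print(zero_point.grad)                 # tensor([1., 1., 1.])
```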

torchao/quantization/qat/linear.py

Lines changed: 10 additions & 6 deletions

@@ -18,6 +18,7 @@
     _replace_linear_int4,
     groupwise_affine_quantize_tensor,
 )
+from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_primitives import (
     TorchAODType,
     ZeroPointDomain,
@@ -83,12 +84,13 @@ def __init__(
 
         # initialize weight fake quantizer
         if weight_config is not None:
-            group_size = weight_config.group_size
-            if group_size is not None and in_features % group_size != 0:
-                raise ValueError(
-                    "in_features (%s) %% group_size (%s) must be == 0"
-                    % (in_features, group_size)
-                )
+            if isinstance(weight_config.granularity, PerGroup):
+                group_size = weight_config.group_size
+                if group_size is not None and in_features % group_size != 0:
+                    raise ValueError(
+                        "in_features (%s) %% group_size (%s) must be == 0"
+                        % (in_features, group_size)
+                    )
             self.weight_fake_quantizer = FakeQuantizer(weight_config)
         else:
             self.weight_fake_quantizer = None
@@ -108,6 +110,7 @@ def to_linear(self) -> torch.nn.Linear:
             self.out_features,
             self.bias is not None,
             device=self.weight.device,
+            dtype=self.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
@@ -131,6 +134,7 @@ def from_linear(
             activation_config=activation_config,
             weight_config=weight_config,
             device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
