Commit 25814b9

Update README, add Int4UnifTorchaoQuantizer
1 parent 667dfe7 commit 25814b9

5 files changed: +119 −38 lines


test/prototype/test_parq.py

+10-25
@@ -18,6 +18,7 @@
     QuantOptimizer,
 )
 from torchao.prototype.parq.quant import (
+    Int4UnifTorchaoQuantizer,
     LSBQuantizer,
     UnifQuantizer,
     UnifTorchaoQuantizer,
@@ -29,7 +30,6 @@
     _is_linear,
     int4_weight_only,
 )
-from torchao.quantization.quant_primitives import ZeroPointDomain
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_6,
@@ -139,20 +139,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)

-    @staticmethod
-    def int4_torchao_quantizer(config, b: int = 4):
-        # based off torchao.quantization.quant_api._int4_weight_only_transform
-        return UnifTorchaoQuantizer(
-            symmetric=False,
-            target_dtype=torch.int32,
-            quant_min=0,
-            quant_max=2**b - 1,
-            eps=1e-6,
-            preserve_zero=False,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-            config=config,
-        )
-
     def compare_quantized_models(
         self,
         model: nn.Module,
@@ -190,8 +176,9 @@ def test_int4_weight_only(self, group_size: int = 32):
         quantize_(m_ref, config)

         b = 4
-        quantizer = self.int4_torchao_quantizer(config)
-        self.compare_quantized_models(model, m_ref, quantizer, b, group_size)
+        self.compare_quantized_models(
+            model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size
+        )

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @common_utils.parametrize("b", [2, 3, 4, 8])
@@ -208,11 +195,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
             ),
         )

-        quantizer = UnifTorchaoQuantizer(
-            symmetric=True,
-            preserve_zero=True,
-            zero_point_domain=ZeroPointDomain.INT,
-        )
+        quantizer = UnifTorchaoQuantizer(symmetric=True)
         self.compare_quantized_models(model, m_ref, quantizer, b, group_size)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@@ -235,9 +218,11 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
         ]
         base_optimizer = torch.optim.AdamW(param_groups)

-        quantizer = self.int4_torchao_quantizer(config)
         optimizer = QuantOptimizer(
-            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+            base_optimizer,
+            Int4UnifTorchaoQuantizer(),
+            ProxHardQuant(),
+            quant_per_channel=True,
         )

         # do not update model weights, just quantize
@@ -248,7 +233,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):

         # equivalent to torchao's convert step
         model.eval()
-        optimizer.torchao_quantize_(model)
+        optimizer.torchao_quantize_(model, config)

         for n, module in model.named_modules():
             if not _is_linear(module):
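
With these test changes, the int4 path no longer builds a bespoke `UnifTorchaoQuantizer` per test; it constructs `Int4UnifTorchaoQuantizer()` directly and supplies the torchao config only at convert time. Below is a rough sketch of the end-to-end flow the `test_int4_weight_only_e2e` case exercises; the toy model, the parameter split, and the exact import paths are illustrative assumptions, not code taken from this commit.

```python
# Hypothetical sketch of the prepare -> step -> convert flow; import paths
# and the toy model are assumptions for illustration.
import torch
import torch.nn as nn

from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import Int4UnifTorchaoQuantizer
from torchao.quantization import int4_weight_only

model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 16))
config = int4_weight_only(group_size=32)

# quantize only the linear weights; keep biases in full precision
params_quant = [m.weight for m in model if isinstance(m, nn.Linear)]
params_no_quant = [m.bias for m in model if isinstance(m, nn.Linear)]
param_groups = [
    {"params": params_quant, "quant_bits": 4, "quant_block_size": 32},
    {"params": params_no_quant},
]
optimizer = QuantOptimizer(
    torch.optim.AdamW(param_groups),
    Int4UnifTorchaoQuantizer(),
    ProxHardQuant(),
    quant_per_channel=True,
)

# a dummy forward/backward so the wrapped AdamW has gradients to step with;
# the step also projects the weights onto the quantized grid
model(torch.randn(8, 256)).sum().backward()
optimizer.step()

# convert: swap in torchao's quantized representation using the same config
model.eval()
optimizer.torchao_quantize_(model, config)
```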

torchao/prototype/parq/README.md

+85-4
@@ -7,23 +7,104 @@ This library applies QAT without modifying model-level code. It instead interfac
 * quantization method: computing the best set of discrete, quantized values
 * proximal mapping: projection of weights onto quantized values

+
+## PARQ vs. torchao
+
+There are two main QAT interfaces in torchao:
+
+- Swap modules (e.g., `torch.nn.Linear`) with their quantized counterparts (e.g., `Int4WeightOnlyQATLinear`). See [Quantizer API (legacy)](torchao/quantization/qat#quantizer-api-legacy) for details.
+- Replace instances of `torch.Tensor` with a quantized tensor subclass such as `AffineQuantizedTensor`. The [`quantize_` API](quantization/qat#quantize_-api-recommended) uses this method by default.
+
+PARQ is conceptually more similar to the tensor subclass interface. It quantizes tensors through the optimizer's parameter groups, leaving the model untouched.
+
+An example PARQ flow and its torchao equivalent are shown below. The prepare stage occurs before training, while the convert stage runs after training to produce a quantized model.
+
+<table>
+<tr>
+<td align="center"><b>stage</b></td><td align="center"><b>PARQ</b></td><td align="center"><b>torchao</b></td>
+</tr>
+<tr>
+<td>prepare</td>
+<td valign="top">
+
+```python
+from torchao.prototype.parq.optim import QuantOptimizer
+from torchao.prototype.parq.quant import UnifTorchaoQuantizer
+
+param_groups = [
+    {"params": params_quant, "quant_bits": 4, "quant_block_size": 32},
+    {"params": params_no_quant},
+]
+base_optimizer = torch.optim.AdamW(param_groups, ...)
+optimizer = QuantOptimizer(
+    base_optimizer,
+    UnifTorchaoQuantizer(symmetric=True),
+    ProxHardQuant(),
+    quant_per_channel=True,
+)
+```
+
+</td>
+<td valign="top">
+
+```python
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    intx_quantization_aware_training,
+)
+
+weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+quantize_(
+    model,
+    intx_quantization_aware_training(weight_config=weight_config),
+)
+```
+
+</td>
+</tr>
+<tr>
+<td>convert</td>
+<td valign="top">
+
+```python
+config = IntXWeightOnlyConfig(
+    weight_dtype=torch.int4, granularity=PerGroup(32)
+)
+optimizer.torchao_quantize_(model, config)
+```
+
+</td>
+<td valign="top">
+
+```python
+from torchao.quantization.qat import from_intx_quantization_aware_training
+
+quantize_(model, from_intx_quantization_aware_training())
+```
+
+</td>
+</tr>
+</table>
+
+Note that `UnifTorchaoQuantizer` calls the same quantization primitives as torchao to match the numerics (see [Affine Quantization Details](torchao/quantization#affine-quantization-details)).
+
 ## QAT arguments

 | | description | choices |
 | --- | --- | --- |
-| `quant-bits` | bit-width for quantized weights | 0 (ternary), 14 |
+| `quant-bits` | bit-width for quantized weights | 0 (ternary), 1-4 |
 | `quant-method` | method for determining quantized values | `lsbq`, `uniform` |
 | `quant-proxmap` | proximal mapping to project weights onto quantized values | `hard`, `parq`, `binaryrelax` |
 | `anneal-start` | start epoch for QAT annealing period | (0, `total_steps` - 1) |
 | `anneal-end` | end epoch for QAT annealing period | (`anneal_end`, `total_steps`) |
-| `anneal-steepness` | sigmoid steepness for PARQ inverse slope schedule | 25—100 |
+| `anneal-steepness` | sigmoid steepness for PARQ inverse slope schedule | 1-20 |

 ## Optimizer-only interface

 The `QuantOptimizer` wrapper takes any `torch.optim.Optimizer` object. It is also initialized with a `Quantizer` and `ProxMap` object. Integration into new training pipelines is simple:
 ```python
-from parq.optim import ProxPARQ, QuantOptimizer
-from parq.quant import LSBQuantizer
+from torchao.prototype.parq.optim import ProxPARQ, QuantOptimizer
+from torchao.prototype.parq.quant import LSBQuantizer


 # split params into quantizable and non-quantizable params
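
The new README table leaves the middle of the flow implicit: what happens between prepare and convert. Here is a minimal sketch of that training portion, assuming `model`, `dataloader`, `loss_fn`, and `num_epochs` already exist and that `optimizer` and `config` are built exactly as in the PARQ column above; none of this comes from the README itself.

```python
# Hypothetical training loop between the README's prepare and convert steps.
# model, dataloader, loss_fn, and num_epochs are placeholders.
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        loss = loss_fn(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # QuantOptimizer projects weights onto quantized values here

# convert after training, as in the PARQ "convert" cell above
model.eval()
optimizer.torchao_quantize_(model, config)
```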

torchao/prototype/parq/optim/quantopt.py

+4-4
@@ -14,6 +14,7 @@
 from torch.optim import Optimizer

 from torchao import quantize_
+from torchao.core.config import AOBaseConfig

 from ..quant import Quantizer
 from ..utils import HAS_DTENSOR, is_dtensor
@@ -109,9 +110,8 @@ def quantize_(
         return q

     @torch.no_grad()
-    def torchao_quantize_(self, model):
-        assert hasattr(self.quantizer, "config"), "Missing self.quantizer.config"
-
+    def torchao_quantize_(self, model: torch.nn.Module, config: AOBaseConfig):
+        """Recursively call torchao.quantize_ on model using given config."""
         self.restore_latent_params()
         param_set = {
             p.data_ptr()
@@ -123,7 +123,7 @@ def inner_quantize_(model):
             for module in model.children():
                 for param in module.parameters(recurse=False):
                     if param.data_ptr() in param_set:
-                        quantize_(module, self.quantizer.config)
+                        quantize_(module, config)
                         break

             inner_quantize_(module)
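
The signature change moves the torchao config from quantizer state to an explicit argument of the convert call. A hedged example of the new call site, assuming `model` and a `QuantOptimizer` instance already exist; `int4_weight_only` is used purely as an example config.

```python
# The config is now passed at convert time rather than read from
# self.quantizer.config.
from torchao.quantization import int4_weight_only

model.eval()
optimizer.torchao_quantize_(model, int4_weight_only(group_size=32))
```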

torchao/prototype/parq/quant/__init__.py

+4-1
@@ -11,4 +11,7 @@
     TernaryUnifQuantizer,
     UnifQuantizer,
 )
-from .uniform_torchao import UnifTorchaoQuantizer  # noqa: F401
+from .uniform_torchao import (  # noqa: F401
+    Int4UnifTorchaoQuantizer,
+    UnifTorchaoQuantizer,
+)
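
With this re-export, both quantizers are importable from the same prototype namespace. A small illustrative snippet (the `symmetric=True` setting mirrors the updated `test_intx_weight_only` test above):

```python
# Both quantizers now come from the same package.
from torchao.prototype.parq.quant import (
    Int4UnifTorchaoQuantizer,
    UnifTorchaoQuantizer,
)

int4_quantizer = Int4UnifTorchaoQuantizer()  # preconfigured int4 weight-only settings
intx_quantizer = UnifTorchaoQuantizer(symmetric=True)  # generic b-bit path
```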

torchao/prototype/parq/quant/uniform_torchao.py

+16-4
@@ -9,7 +9,6 @@
 import torch
 from torch import Tensor

-from torchao.core.config import AOBaseConfig
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_BIT_WIDTH,
     _DTYPE_TO_QVALUE_BOUNDS,
@@ -36,8 +35,7 @@ def __init__(
         quant_max: Optional[Union[int, float]] = None,
         eps: Optional[float] = None,
         preserve_zero: bool = True,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.NONE,
-        config: Optional[AOBaseConfig] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> None:
         super().__init__(center=False)

@@ -50,7 +48,6 @@ def __init__(
         self.eps = eps
         self.preserve_zero = preserve_zero
         self.zero_point_domain = zero_point_domain
-        self.config = config

     def _init_quant_min_max(self, b: int) -> None:
         if self.quant_min is None or self.quant_max is None:
@@ -121,3 +118,18 @@ def quantize(
             zero_point_domain=self.zero_point_domain,
         )
         return q, Q
+
+
+class Int4UnifTorchaoQuantizer(UnifTorchaoQuantizer):
+    """Based on torchao.quantization.quant_api._int4_weight_only_transform"""
+
+    def __init__(self) -> None:
+        super().__init__(
+            symmetric=False,
+            target_dtype=torch.int32,
+            quant_min=0,
+            quant_max=15,
+            eps=1e-6,
+            preserve_zero=False,
+            zero_point_domain=ZeroPointDomain.FLOAT,
+        )
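
Per this diff, `Int4UnifTorchaoQuantizer` is simply `UnifTorchaoQuantizer` with the int4 weight-only arguments pinned (and without the removed `config` argument). A sketch of that equivalence, using only names that appear in the diff above:

```python
# These two constructions configure the same quantizer; the subclass merely
# hard-codes the arguments taken from _int4_weight_only_transform.
import torch

from torchao.prototype.parq.quant import Int4UnifTorchaoQuantizer, UnifTorchaoQuantizer
from torchao.quantization.quant_primitives import ZeroPointDomain

q_int4 = Int4UnifTorchaoQuantizer()
q_manual = UnifTorchaoQuantizer(
    symmetric=False,
    target_dtype=torch.int32,
    quant_min=0,
    quant_max=15,
    eps=1e-6,
    preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
)
```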
