
Commit 6c1b813

Add test_intx_weight_only_e2e, set UnifTorchaoQuantizer defaults
1 parent 490bdf6 commit 6c1b813

3 files changed: +106, -94 lines

test/prototype/test_parq.py (+81, -64)
@@ -5,12 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import unittest
+from typing import Optional
 
 import torch
 from torch import nn
 from torch.testing._internal import common_utils
 
-from torchao import quantize_
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import Int4CPULayout
 from torchao.prototype.parq.optim import (
     ProxHardQuant,
@@ -29,6 +30,7 @@
     IntxWeightOnlyConfig,
     _is_linear,
     int4_weight_only,
+    quantize_,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
@@ -44,18 +46,26 @@ def split_param_groups(model):
 
     def get_param_groups(model):
         for module in model.children():
-            is_linear = isinstance(module, nn.Linear)
-            for n, p in module.named_parameters(recurse=False):
+            is_linear = _is_linear(module)
+            for n, p in module.named_parameters():
                 if is_linear and n == "weight":
                     params_quant.append(p)
                 else:
                     params_no_quant.append(p)
-            get_param_groups(module)
 
     get_param_groups(model)
     return params_quant, params_no_quant
 
 
+def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
+    params_quant, params_no_quant = split_param_groups(model)
+    quant_kwargs = {"quant_block_size": group_size} if group_size else {}
+    return [
+        {"params": params_quant, "quant_bits": b, **quant_kwargs},
+        {"params": params_no_quant},
+    ]
+
+
 class M(nn.Module):
     def __init__(self, m=256, n=128, k=16, bias=False):
         super().__init__()
@@ -87,46 +97,28 @@ class TestPARQuantization(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.model = M(bias=True).to(_DEVICE)
-        self.params_quant, self.params_no_quant = split_param_groups(self.model)
 
-    def train_loop(self, optimizer, steps=1):
-        for _ in range(steps):
-            x = self.model.example_inputs(device=_DEVICE)
-            out = self.model(x)
-            out.sum().backward()
-            optimizer.step()
+    @common_utils.parametrize("b", [0, 1, 2, 4])
+    @common_utils.parametrize("unif_quant", [True, False])
+    @common_utils.parametrize("hard_prox", [True, False])
+    def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True):
+        if unif_quant and b == 0:
+            self.skipTest("Ternary uniform quantization not yet supported")
 
-    def test_2bit_unif_quantizer_hard_prox(self):
-        b = 2
         self.model.reset_parameters()
-        param_groups = [
-            {"params": self.params_quant, "quant_bits": b},
-            {"params": self.params_no_quant},
-        ]
+        param_groups = build_param_groups(self.model, b)
         base_optimizer = torch.optim.AdamW(param_groups)
-        quantizer = UnifQuantizer()
-        optimizer = QuantOptimizer(base_optimizer, quantizer, ProxHardQuant())
-        self.train_loop(optimizer)
 
-        for child in self.model.children():
-            if isinstance(child, nn.Linear):
-                self.assertEqual(
-                    child.weight.unique().numel(), quantizer.get_quant_size(b)
-                )
-
-    def test_ternarybit_lsbq_parq_prox(self):
-        b = 0
-        self.model.reset_parameters()
-        param_groups = [
-            {"params": self.params_quant, "quant_bits": b},
-            {"params": self.params_no_quant},
-        ]
-        base_optimizer = torch.optim.AdamW(param_groups)
-        quantizer = LSBQuantizer()
-        optimizer = QuantOptimizer(
-            base_optimizer, quantizer, ProxPARQ(anneal_start=0, anneal_end=2)
+        quantizer = UnifQuantizer() if unif_quant else LSBQuantizer()
+        prox_map = (
+            ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
-        self.train_loop(optimizer, steps=3)
+        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
+        for _ in range(3):
+            x = self.model.example_inputs(device=_DEVICE)
+            out = self.model(x)
+            out.sum().backward()
+            optimizer.step()
 
         for child in self.model.children():
             if isinstance(child, nn.Linear):
@@ -163,6 +155,35 @@ def compare_quantized_models(
             ref = getattr(m_ref, n).weight.dequantize()
             self.assertTrue(q.equal(ref))
 
+    def compare_parq_convert(
+        self,
+        model: nn.Module,
+        m_ref: nn.Module,
+        optimizer: QuantOptimizer,
+        config: AOBaseConfig,
+    ):
+        # do not update model weights, just quantize
+        optimizer.zero_grad()
+        optimizer.step()
+
+        orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
+
+        # equivalent to torchao's convert step
+        model.eval()
+        optimizer.restore_latent_params()
+        quantize_(model, config, filter_fn=optimizer._get_filter_fn(model))
+
+        for n, module in model.named_modules():
+            if not _is_linear(module):
+                continue
+
+            p_orig = getattr(orig_model, n).weight  # PARQ weight
+            p = module.weight.dequantize()  # PARQ weight after quantize_
+            p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
+
+            self.assertTrue(p_orig.equal(p_ref))
+            self.assertTrue(p.equal(p_ref))
+
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
@@ -195,7 +216,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
             ),
         )
 
-        quantizer = UnifTorchaoQuantizer(symmetric=True)
+        quantizer = UnifTorchaoQuantizer()
         self.compare_quantized_models(model, m_ref, quantizer, b, group_size)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@@ -211,40 +232,36 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
         quantize_(m_ref, config)
 
         b = 4
-        params_quant, params_no_quant = split_param_groups(model)
-        param_groups = [
-            {"params": params_quant, "quant_bits": b, "quant_block_size": group_size},
-            {"params": params_no_quant},
-        ]
-        base_optimizer = torch.optim.AdamW(param_groups)
-
+        base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
         optimizer = QuantOptimizer(
             base_optimizer,
             Int4UnifTorchaoQuantizer(),
             ProxHardQuant(),
             quant_per_channel=True,
         )
+        self.compare_parq_convert(model, m_ref, optimizer, config)
 
-        # do not update model weights, just quantize
-        optimizer.zero_grad()
-        optimizer.step()
-
-        orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
-
-        # equivalent to torchao's convert step
-        model.eval()
-        optimizer.torchao_quantize_(model, config)
-
-        for n, module in model.named_modules():
-            if not _is_linear(module):
-                continue
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @common_utils.parametrize("b", [2, 3, 4, 8])
+    def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
 
-            p_orig = getattr(orig_model, n).weight  # PARQ weight
-            p = module.weight.dequantize()  # PARQ weight after quantize_
-            p_ref = getattr(m_ref, n).weight.dequantize()  # natively quantize_
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        config = IntxWeightOnlyConfig(
+            weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(group_size)
+        )
+        quantize_(m_ref, config)
 
-            self.assertTrue(p_orig.equal(p_ref))
-            self.assertTrue(p.equal(p_ref))
+        base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optimizer = QuantOptimizer(
+            base_optimizer,
+            UnifTorchaoQuantizer(),
+            ProxHardQuant(),
+            quant_per_channel=True,
+        )
+        self.compare_parq_convert(model, m_ref, optimizer, config)
 
 
 common_utils.instantiate_parametrized_tests(TestPARQuantization)
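
Note: the end-to-end flow the new test_intx_weight_only_e2e exercises looks roughly like the sketch below: PARQ param groups drive a QuantOptimizer, and after training the model is handed to torchao's quantize_ via the optimizer's new filter function. This is a minimal sketch rather than the test itself; the nn.Sequential stand-in, the random batch, the id-based parameter split, and the torchao.quantization / torchao.quantization.granularity / torchao.prototype.parq.quant import paths are assumptions, torch.int4 needs PyTorch 2.6+, and the test additionally skips on CPU.

    import torch
    from torch import nn

    from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
    from torchao.prototype.parq.quant import UnifTorchaoQuantizer
    from torchao.quantization import IntxWeightOnlyConfig, quantize_
    from torchao.quantization.granularity import PerGroup

    # toy stand-in for the test's M module
    model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))

    # param groups in the shape build_param_groups produces: quantized weights
    # carry quant_bits/quant_block_size, everything else is left untouched
    b, group_size = 4, 32
    weights = [m.weight for m in model if isinstance(m, nn.Linear)]
    weight_ids = {id(w) for w in weights}
    param_groups = [
        {"params": weights, "quant_bits": b, "quant_block_size": group_size},
        {"params": [p for p in model.parameters() if id(p) not in weight_ids]},
    ]

    optimizer = QuantOptimizer(
        torch.optim.AdamW(param_groups),
        UnifTorchaoQuantizer(),
        ProxHardQuant(),
        quant_per_channel=True,
    )

    # one PARQ training step on a random batch
    model(torch.randn(8, 512)).sum().backward()
    optimizer.step()

    # torchao-style convert step, mirroring compare_parq_convert above
    model.eval()
    optimizer.restore_latent_params()
    config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(group_size))
    quantize_(model, config, filter_fn=optimizer._get_filter_fn(model))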

torchao/prototype/parq/optim/quantopt.py (+21, -22)
@@ -13,9 +13,6 @@
 from torch import Tensor
 from torch.optim import Optimizer
 
-from torchao import quantize_
-from torchao.core.config import AOBaseConfig
-
 from ..quant import Quantizer
 from ..utils import HAS_DTENSOR, is_dtensor
 from .proxmap import ProxMap
@@ -109,32 +106,32 @@ def quantize_(
             quants.copy_(Q)
         return q
 
-    @torch.no_grad()
-    def torchao_quantize_(self, model: torch.nn.Module, config: AOBaseConfig):
-        """Recursively call torchao.quantize_ on model using given config."""
-        self.restore_latent_params()
-        param_set = {
+    def regularized_param_groups(self):  # pyre-ignore[3]
+        """Yield parameter groups that need to be quantized."""
+        for group in self.param_groups:
+            if group.get("quant_bits", 16) < 16:
+                yield group
+
+    @property
+    def _param_set(self) -> set[int]:
+        return {
             p.data_ptr()
             for group in self.regularized_param_groups()
             for p in group["params"]
         }
 
-        def inner_quantize_(model):
-            for module in model.children():
-                for param in module.parameters(recurse=False):
-                    if param.data_ptr() in param_set:
-                        quantize_(module, config)
-                        break
+    def _get_filter_fn(
+        self, module: torch.nn.Module
+    ) -> Callable[[torch.nn.Module], bool]:
+        param_set = self._param_set
 
-            inner_quantize_(module)
+        def _filter_fn(module: torch.nn.Module, *args) -> bool:
+            for p in module.parameters(recurse=False):
+                if p.data_ptr() in param_set:
+                    return True
+            return False
 
-        inner_quantize_(model)
-
-    def regularized_param_groups(self):  # pyre-ignore[3]
-        """Yield parameter groups that need to be quantized."""
-        for group in self.param_groups:
-            if group.get("quant_bits", 16) < 16:
-                yield group
+        return _filter_fn
 
     @torch._disable_dynamo
     def state_dict(self) -> dict[str, Any]:
@@ -285,6 +282,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         return loss
 
     @torch._disable_dynamo
+    @torch.no_grad()
    def restore_latent_params(self) -> None:
         """Restore latent parameters as optimizer parameters"""
         for group in self.regularized_param_groups():
@@ -293,6 +291,7 @@ def restore_latent_params(self) -> None:
                 p.copy_(self.state[p]["latent"])
 
     @torch._disable_dynamo
+    @torch.no_grad()
     def save_latent_params(self) -> None:
         """Save updated latent parameters before applying prox-map"""
         if self.warmup_steps == 0:
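
Note: the removed torchao_quantize_ method is not replaced one-for-one; callers now compose torchao's public quantize_ with the new _get_filter_fn, whose returned predicate is true only for modules whose own parameters sit in a param group with quant_bits < 16. A minimal sketch of the equivalent call site follows; parq_convert is a hypothetical helper name and the torchao.quantization import path is an assumption, while the filter_fn usage mirrors compare_parq_convert in the test diff above.

    import torch

    from torchao.quantization import quantize_  # assumed import path


    def parq_convert(model: torch.nn.Module, config, optimizer) -> None:
        """Sketch of the call-site replacement for the removed torchao_quantize_."""
        model.eval()
        # copy latent full-precision weights back into the params (now runs under no_grad)
        optimizer.restore_latent_params()
        # quantize only modules whose parameters belong to a PARQ-regularized group;
        # torchao calls the predicate once per module, and _filter_fn absorbs the
        # extra fully-qualified-name argument via *args
        quantize_(model, config, filter_fn=optimizer._get_filter_fn(model))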

torchao/prototype/parq/quant/uniform_torchao.py (+4, -8)
@@ -29,8 +29,8 @@ class UnifTorchaoQuantizer(Quantizer):
 
     def __init__(
         self,
-        symmetric: bool,
-        target_dtype: Optional[torch.dtype] = None,
+        mapping_type: MappingType = MappingType.SYMMETRIC,
+        target_dtype: torch.dtype = torch.int8,
         quant_min: Optional[Union[int, float]] = None,
         quant_max: Optional[Union[int, float]] = None,
         eps: Optional[float] = None,
@@ -39,9 +39,7 @@ def __init__(
     ) -> None:
         super().__init__(center=False)
 
-        self.mapping_type = (
-            MappingType.SYMMETRIC if symmetric else MappingType.ASYMMETRIC
-        )
+        self.mapping_type = mapping_type
         self.target_dtype = target_dtype
         self.quant_min = quant_min
         self.quant_max = quant_max
@@ -55,8 +53,6 @@ def _init_quant_min_max(self, b: int) -> None:
         self.quant_min, self.quant_max = _DTYPE_TO_QVALUE_BOUNDS[
             _BIT_WIDTH_TO_DTYPE[b]
         ]
-        if self.target_dtype is None:
-            self.target_dtype = torch.int8
 
     def get_quant_size(self, b: int) -> int:
         self._init_quant_min_max(b)
@@ -125,7 +121,7 @@ class Int4UnifTorchaoQuantizer(UnifTorchaoQuantizer):
 
     def __init__(self) -> None:
         super().__init__(
-            symmetric=False,
+            mapping_type=MappingType.ASYMMETRIC,
             target_dtype=torch.int32,
             quant_min=0,
             quant_max=15,
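
Note: with these defaults, the common symmetric int8-backed configuration needs no constructor arguments, and asymmetric callers name a MappingType explicitly instead of passing symmetric=False, as Int4UnifTorchaoQuantizer now does. A small sketch of the two call sites, assuming the quantizers are importable from torchao.prototype.parq.quant and MappingType from torchao.quantization.quant_primitives:

    import torch

    from torchao.prototype.parq.quant import Int4UnifTorchaoQuantizer, UnifTorchaoQuantizer
    from torchao.quantization.quant_primitives import MappingType

    # new default: symmetric mapping onto an int8 target, no arguments required
    quantizer = UnifTorchaoQuantizer()
    assert quantizer.mapping_type is MappingType.SYMMETRIC
    assert quantizer.target_dtype is torch.int8

    # asymmetric callers spell out the mapping type instead of symmetric=False
    asym = UnifTorchaoQuantizer(mapping_type=MappingType.ASYMMETRIC)
    int4 = Int4UnifTorchaoQuantizer()  # asymmetric, torch.int32 target, quant range [0, 15]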
