
Commit 81ba287

Add QuantOptimizer.torchao_quantize_
1 parent d7710cf commit 81ba287

3 files changed (+53 -18 lines)


test/prototype/test_parq.py (+1 -3)

@@ -248,9 +248,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
 
         # equivalent to torchao's convert step
         model.eval()
-        with torch.no_grad():
-            optimizer.restore_latent_params()
-            quantize_(model, quantizer.config)
+        optimizer.torchao_quantize_(model)
 
         for n, module in model.named_modules():
             if not _is_linear(module):
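With this change, the convert step in user code collapses to a single call. A minimal sketch mirroring the test above; model, optimizer, and quantizer are assumed to already exist from the usual PARQ prepare/train flow:

# Before this commit, the convert step was spelled out by hand:
model.eval()
with torch.no_grad():
    optimizer.restore_latent_params()   # copy latent full-precision params back into the model
    quantize_(model, quantizer.config)  # torchao convert step

# After this commit, torchao_quantize_ performs both steps:
model.eval()
optimizer.torchao_quantize_(model)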

torchao/prototype/parq/optim/quantopt.py (+27 -3)

@@ -13,6 +13,8 @@
 from torch import Tensor
 from torch.optim import Optimizer
 
+from torchao import quantize_
+
 from ..quant import Quantizer
 from ..utils import HAS_DTENSOR, is_dtensor
 from .proxmap import ProxMap

@@ -106,6 +108,28 @@ def quantize_(
         quants.copy_(Q)
         return q
 
+    @torch.no_grad()
+    def torchao_quantize_(self, model):
+        assert hasattr(self.quantizer, "config"), "Missing self.quantizer.config"
+
+        self.restore_latent_params()
+        param_set = {
+            p.data_ptr()
+            for group in self.regularized_param_groups()
+            for p in group["params"]
+        }
+
+        def inner_quantize_(model):
+            for module in model.children():
+                for param in module.parameters(recurse=False):
+                    if param.data_ptr() in param_set:
+                        quantize_(module, self.quantizer.config)
+                        break
+
+                inner_quantize_(module)
+
+        inner_quantize_(model)
+
     def regularized_param_groups(self):  # pyre-ignore[3]
         """Yield parameter groups that need to be quantized."""
         for group in self.param_groups:

@@ -189,9 +213,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
 
             # reshape p according to block size if specified
             if block_size is not None:
-                assert p.size(-1) % block_size == 0, (
-                    f"{p.size(-1)=} is not divisible by {block_size=}"
-                )
+                assert (
+                    p.size(-1) % block_size == 0
+                ), f"{p.size(-1)=} is not divisible by {block_size=}"
                 assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}"
                 if p.dim() == 1:
                     p = p.unsqueeze(0)
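The new method converts a module only when one of its own (non-recursive) parameters belongs to the optimizer's regularized param groups, matching by data_ptr. A self-contained sketch of just that traversal logic; selective_quantize_ and apply_config are illustrative stand-ins for the new method and for torchao's quantize_(module, config), not part of this commit:

from torch import nn


def selective_quantize_(model: nn.Module, param_set: set, apply_config) -> None:
    # Mirrors inner_quantize_: recurse over children, converting a child module
    # only if one of its directly-owned parameters was tracked by the optimizer.
    for module in model.children():
        for param in module.parameters(recurse=False):
            if param.data_ptr() in param_set:
                apply_config(module)
                break
        selective_quantize_(module, param_set, apply_config)


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
# Pretend only the first Linear's weight sits in a regularized param group.
tracked = {model[0].weight.data_ptr()}
selective_quantize_(model, tracked, lambda m: print("would quantize:", m))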

torchao/prototype/parq/quant/uniform_torchao.py (+25 -12)

@@ -52,14 +52,6 @@ def __init__(
         self.zero_point_domain = zero_point_domain
         self.config = config
 
-    @property
-    def q_kwargs(self) -> dict[str, Union[int, float]]:
-        return {
-            "quant_min": self.quant_min,
-            "quant_max": self.quant_max,
-            "zero_point_domain": self.zero_point_domain,
-        }
-
     def _init_quant_min_max(self, b: int) -> None:
         if self.quant_min is None or self.quant_max is None:
             assert b in _BIT_WIDTH_TO_DTYPE, f"Unsupported bitwidth {b}"

@@ -89,11 +81,26 @@ def quantize(
             self.target_dtype,
             eps=self.eps,
             preserve_zero=self.preserve_zero,
-            **self.q_kwargs,
+            quant_min=self.quant_min,
+            quant_max=self.quant_max,
+            zero_point_domain=self.zero_point_domain,
         )
         q_args = (block_size, s, zero_point, self.target_dtype)
-        q = quantize_affine(p, *q_args, **self.q_kwargs)
-        q = dequantize_affine(q, *q_args, output_dtype=p.dtype, **self.q_kwargs)
+        q = quantize_affine(
+            p,
+            *q_args,
+            quant_min=self.quant_min,
+            quant_max=self.quant_max,
+            zero_point_domain=self.zero_point_domain,
+        )
+        q = dequantize_affine(
+            q,
+            *q_args,
+            output_dtype=p.dtype,
+            quant_min=self.quant_min,
+            quant_max=self.quant_max,
+            zero_point_domain=self.zero_point_domain,
+        )
 
         Q = torch.arange(
             self.quant_min, self.quant_max + 1, dtype=self.target_dtype, device=p.device

@@ -105,6 +112,12 @@ def quantize(
             block_size = Q.shape
 
         Q = dequantize_affine(
-            Q, block_size, *q_args[1:], output_dtype=p.dtype, **self.q_kwargs
+            Q,
+            block_size,
+            *q_args[1:],
+            output_dtype=p.dtype,
+            quant_min=self.quant_min,
+            quant_max=self.quant_max,
+            zero_point_domain=self.zero_point_domain,
         )
         return q, Q
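The q_kwargs property is dropped in favor of explicit quant_min / quant_max / zero_point_domain keywords at each call site. A minimal round-trip sketch of that call pattern, assuming the same torchao quant_primitives imports and signatures this file targets; the scale and zero point below are hand-built for illustration rather than produced by choose_qparams_affine:

import torch
from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
    dequantize_affine,
    quantize_affine,
)

p = torch.randn(4, 32)
block_size = (1, 32)  # one quantization block (hence one scale) per row
s = p.abs().amax(dim=-1, keepdim=True) / 7
zero_point = torch.zeros_like(s, dtype=torch.int32)
q_args = (block_size, s, zero_point, torch.int8)  # same tuple layout as in quantize()

q = quantize_affine(
    p, *q_args, quant_min=-8, quant_max=7, zero_point_domain=ZeroPointDomain.INT
)
q_dq = dequantize_affine(
    q,
    *q_args,
    output_dtype=p.dtype,
    quant_min=-8,
    quant_max=7,
    zero_point_domain=ZeroPointDomain.INT,
)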
