
Commit a587847

Test IntxWeightOnlyConfig
1 parent b26d8d2 commit a587847

2 files changed: +93 -45 lines

test/prototype/test_parq.py

+82 -36
```diff
@@ -7,7 +7,8 @@
 import unittest
 
 import torch
-from torch.testing._internal.common_utils import TestCase
+from torch import nn
+from torch.testing._internal import common_utils
 
 from torchao import quantize_
 from torchao.prototype.parq.optim import (
```
```diff
@@ -20,8 +21,17 @@
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
-from torchao.quantization.quant_api import int4_weight_only
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.quant_api import (
+    Int4WeightOnlyConfig,
+    IntxWeightOnlyConfig,
+    _is_linear,
+)
+from torchao.quantization.quant_primitives import (
+    _DTYPE_TO_QVALUE_BOUNDS,
+    ZeroPointDomain,
+)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
 
 _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
```
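The import change above tracks torchao's move from quantization callables to config objects: `int4_weight_only(group_size=...)` gives way to `Int4WeightOnlyConfig`, which is handed to `quantize_` alongside the module. A minimal sketch of the new call pattern, assuming a CUDA device and a torchao version that ships the config API (the toy model is illustrative, not from this diff):

```python
import torch
from torchao import quantize_
from torchao.quantization.quant_api import Int4WeightOnlyConfig

# illustrative model; the int4 weight-only path expects bf16 weights on GPU
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# quantize_ mutates the model in place: each linear weight is replaced by
# an AffineQuantizedTensor built from the config's settings
quantize_(model, Int4WeightOnlyConfig(group_size=32))
```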

```diff
@@ -36,20 +46,20 @@ def split_param_groups(model):
     return params_no_quant, params_quant
 
 
-class M(torch.nn.Module):
+class M(nn.Module):
     def __init__(self, m=256, n=128, k=16, bias=False):
         super().__init__()
-        self.embedding = torch.nn.Embedding(10, m)
-        self.linear1 = torch.nn.Linear(m, n, bias=bias)
-        self.linear2 = torch.nn.Linear(n, k, bias=bias)
-        self.relu = torch.nn.ReLU()
-        self.sigmoid = torch.nn.Sigmoid()
+        self.embedding = nn.Embedding(10, m)
+        self.linear1 = nn.Linear(m, n, bias=bias)
+        self.linear2 = nn.Linear(n, k, bias=bias)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
 
     def reset_parameters(self):
         for module in (self.linear1, self.linear2):
-            torch.nn.init.xavier_uniform_(module.weight)
+            nn.init.xavier_uniform_(module.weight)
             if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
+                nn.init.zeros_(module.bias)
 
     def example_inputs(self):
         return torch.randint(1, 10, (1, 256))
```
```diff
@@ -63,7 +73,7 @@ def forward(self, x):
         return x
 
 
-class TestPARQuantization(TestCase):
+class TestPARQuantization(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.model = M(bias=True).to(_DEVICE)
```
```diff
@@ -86,7 +96,7 @@ def test_2bit_unif_quantizer_hard_prox(self):
         optimizer.step()
 
         for child in self.model.children():
-            if isinstance(child, torch.nn.Linear):
+            if isinstance(child, nn.Linear):
                 self.assertEqual(child.weight.unique().numel(), 4)
 
     def test_ternarybit_lsbq_parq_prox(self):
```
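The `unique().numel()` assertions in these tests encode the quantizer's budget: a b-bit quantizer leaves at most 2^b distinct values in each linear weight (4 for the 2-bit case, 3 for ternary). A standalone illustration of the invariant, using a hypothetical nearest-level rounding as a stand-in for the PARQ prox step:

```python
import torch

b = 2
levels = torch.linspace(-1.0, 1.0, 2**b)  # 4 representable values for 2 bits
w = torch.randn(8, 8)
# snap every entry to its nearest level (stand-in for the hard prox map)
q = levels[(w.unsqueeze(-1) - levels).abs().argmin(dim=-1)]
assert q.unique().numel() <= 2**b
```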
```diff
@@ -107,22 +117,47 @@ def test_ternarybit_lsbq_parq_prox(self):
         optimizer.step()
 
         for child in self.model.children():
-            if isinstance(child, torch.nn.Linear):
+            if isinstance(child, nn.Linear):
                 self.assertEqual(child.weight.unique().numel(), 3)
 
 
-class TestUnifTorchaoQuantizer(TestCase):
-    def setUp(self, group_size=32):
+class TestUnifTorchaoQuantizer(common_utils.TestCase):
+    def setUp(self):
         torch.manual_seed(123)
-        self.model = M(n=1024, k=1024).to(torch.bfloat16).to(_DEVICE)
-        self.group_size = group_size
+        self.group_size = 32
+
+    def compare_quantized_models(
+        self,
+        model: nn.Module,
+        m_ref: nn.Module,
+        quantizer: UnifTorchaoQuantizer,
+        b: int,
+    ):
+        for n, module in model.named_children():
+            if not _is_linear(module):
+                continue
+
+            # simulate grouping from QuantOptimizer.step
+            p = module.weight
+            original_shape = p.shape
+            p = p.view(-1, self.group_size)
+
+            q, Q = quantizer.quantize(p, b=b, dim=-1)
+            q = q.view(original_shape)
+
+            # compare to AffineQuantizedTensor instance
+            ref = getattr(m_ref, n).weight.dequantize()
+            self.assertTrue(q.equal(ref))
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     def test_int4_weight_only(self):
-        self.model.reset_parameters()
-        m_copy = copy.deepcopy(self.model)
-        quantize_(m_copy, int4_weight_only(group_size=self.group_size))
+        model = M(n=1024, k=1024).to(torch.bfloat16).to(_DEVICE)
+        model.reset_parameters()
+        m_ref = copy.deepcopy(model)
+
+        config = Int4WeightOnlyConfig(group_size=self.group_size)
+        quantize_(m_ref, config, device=_DEVICE)
 
         # copied from torchao.quantization.quant_api._int4_weight_only_transform
         b = 4
```
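The new `compare_quantized_models` helper reproduces the grouping that `QuantOptimizer.step` applies before quantizing: the 2-D weight is viewed as rows of `group_size` elements so each row gets its own scale, then restored to its original shape afterwards. A shape-only sketch of that round trip (sizes mirror the int4 test; no torchao calls involved):

```python
import torch

group_size = 32
p = torch.randn(1024, 1024)
original_shape = p.shape

grouped = p.view(-1, group_size)      # (32768, 32): one scale per row
# ... quantize along dim=-1 here, one scale/zero point per row ...
restored = grouped.view(original_shape)  # undo the grouping
assert restored.shape == original_shape
```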
```diff
@@ -137,22 +172,33 @@ def test_int4_weight_only(self):
         self.assertTrue(
             quantizer.get_quant_size(b) == quantizer.quant_max - quantizer.quant_min + 1
         )
+        self.compare_quantized_models(model, m_ref, quantizer, b)
 
-        for n, module in self.model.named_children():
-            if not isinstance(module, torch.nn.Linear):
-                continue
-
-            # simulate grouping from QuantOptimizer.step
-            p = module.weight
-            original_shape = p.shape
-            p = p.view(-1, self.group_size)
-
-            q, Q = quantizer.quantize(p, b=b, dim=-1)
-            q = q.view(original_shape)
-
-            # compare to AffineQuantizedTensor instance
-            ref = getattr(m_copy, n).weight.dequantize()
-            self.assertTrue(q.equal(ref))
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    def test_intx_weight_only(self):
+        model = M(n=512, k=512).to(_DEVICE)
+        model.reset_parameters()
+        m_ref = copy.deepcopy(model)
+
+        config = IntxWeightOnlyConfig(granularity=PerGroup(self.group_size))
+        quantize_(m_ref, config, device=_DEVICE)
+        b = 8
+        q_dtype = torch.int8
+        quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[q_dtype]
+        quantizer = UnifTorchaoQuantizer(
+            symmetric=True,
+            target_dtype=q_dtype,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=torch.finfo(torch.float32).eps,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.INT,
+        )
+        self.assertTrue(
+            quantizer.get_quant_size(b) == max(abs(quant_min), quant_max) + 1
+        )
+        self.compare_quantized_models(model, m_ref, quantizer, b)
 
 
 if __name__ == "__main__":
```
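The `get_quant_size` assertion in the new intx test is arithmetic on torch.int8's value range: `_DTYPE_TO_QVALUE_BOUNDS[torch.int8]` is (-128, 127), so `max(abs(quant_min), quant_max) + 1` is 129, which equals `2 ** (b - 1) + 1` for b = 8, the symmetric branch of `get_quant_size` in uniform_torchao.py below. Spelled out:

```python
quant_min, quant_max = -128, 127  # torch.int8 bounds
b = 8
assert max(abs(quant_min), quant_max) + 1 == 129 == 2 ** (b - 1) + 1
```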

torchao/prototype/parq/quant/uniform_torchao.py

+11 -9
```diff
@@ -22,8 +22,7 @@
 
 
 class UnifTorchaoQuantizer(Quantizer):
-    """Uniform quantizer that uses torchao's quantization primitives to determine scale and zero point."""
-
+    """Uniform quantizer that uses torchao's quantization primitives"""
     def __init__(
         self,
         symmetric: bool,
```
```diff
@@ -46,6 +45,14 @@ def __init__(
         self.preserve_zero = preserve_zero
         self.zero_point_domain = zero_point_domain
 
+    @property
+    def q_kwargs(self) -> dict[str, Union[int, float]]:
+        return {
+            "quant_min": self.quant_min,
+            "quant_max": self.quant_max,
+            "zero_point_domain": self.zero_point_domain,
+        }
+
     def get_quant_size(self, b: int) -> int:
         return 2 ** (b - 1) + 1 if self.mapping_type == MappingType.SYMMETRIC else 2**b
 
```
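Hoisting the repeated keyword arguments into a `q_kwargs` property keeps the paired `quantize_affine`/`dequantize_affine` calls below in sync, and because a property is recomputed on each access, the dict always reflects the current attribute values. A stripped-down analogue of the pattern (class and fields are illustrative, not from this diff):

```python
class QParams:
    def __init__(self, quant_min: int, quant_max: int):
        self.quant_min, self.quant_max = quant_min, quant_max

    @property
    def q_kwargs(self) -> dict[str, int]:
        # rebuilt on every access, so later attribute edits are picked up
        return {"quant_min": self.quant_min, "quant_max": self.quant_max}

qp = QParams(-8, 7)
qp.quant_max = 6
assert qp.q_kwargs["quant_max"] == 6
```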

```diff
@@ -72,13 +79,8 @@ def quantize(
             zero_point_domain=self.zero_point_domain,
         )
         q_args = (block_size, s, zero_point, self.target_dtype)
-        q_kwargs = {
-            "quant_min": self.quant_min,
-            "quant_max": self.quant_max,
-            "zero_point_domain": self.zero_point_domain,
-        }
-        q = quantize_affine(p,*q_args,**q_kwargs)
-        q = dequantize_affine(q, *q_args, **q_kwargs, output_dtype=p.dtype)
+        q = quantize_affine(p, *q_args, **self.q_kwargs)
+        q = dequantize_affine(q, *q_args, **self.q_kwargs, output_dtype=p.dtype)
 
         Q = torch.arange(
             self.quant_min, self.quant_max + 1, dtype=self.target_dtype, device=p.device
```
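The `quantize_affine` then `dequantize_affine` pair performs a fake-quantize round trip: `p` is mapped to integer codes, then straight back to `p.dtype`, leaving only the quantization error behind. A hedged sketch of the same round trip with hand-built qparams in place of `choose_qparams_affine` (scale layout assumes torchao's one-scale-per-block convention; the values are illustrative):

```python
import torch
from torchao.quantization.quant_primitives import (
    dequantize_affine,
    quantize_affine,
)

p = torch.randn(4, 32)
block_size = (1, 32)                        # one block per row
s = p.abs().amax(dim=-1, keepdim=True) / 7  # crude symmetric int4-style scale
zero_point = torch.zeros_like(s, dtype=torch.int32)

# mirrors the call shape in the diff: positional q_args, bounds as kwargs
q_args = (block_size, s, zero_point, torch.int8)
q = quantize_affine(p, *q_args, quant_min=-8, quant_max=7)
deq = dequantize_affine(q, *q_args, quant_min=-8, quant_max=7, output_dtype=p.dtype)
assert deq.shape == p.shape and deq.dtype == p.dtype
```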
