Skip to content

Commit bea111c

Browse files
committed
Per-row IntxWeightOnlyConfig test
1 parent 21a0f04 commit bea111c

File tree

3 files changed

+31
-21
lines changed

3 files changed

+31
-21
lines changed

test/prototype/test_parq.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,11 @@ def test_ternarybit_lsbq_parq_prox(self):
122122

123123

124124
class TestUnifTorchaoQuantizer(common_utils.TestCase):
125-
def setUp(self):
125+
def __init__(self, methodName, group_size: int = 32):
126+
super(TestUnifTorchaoQuantizer, self).__init__(methodName)
126127
torch.manual_seed(123)
127-
self.group_size = 32
128+
self.group_size = group_size
129+
self.out_dim = 512
128130

129131
def compare_quantized_models(
130132
self,
@@ -152,14 +154,13 @@ def compare_quantized_models(
152154
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
153155
@unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
154156
def test_int4_weight_only(self):
155-
model = M(n=1024, k=1024).to(torch.bfloat16).to(_DEVICE)
157+
model = M(m=self.out_dim, n=self.out_dim).to(torch.bfloat16).to(_DEVICE)
156158
model.reset_parameters()
157159
m_ref = copy.deepcopy(model)
158160

159-
config = Int4WeightOnlyConfig(group_size=self.group_size)
160-
quantize_(m_ref, config, device=_DEVICE)
161+
quantize_(m_ref, Int4WeightOnlyConfig(group_size=self.group_size))
161162

162-
# copied from torchao.quantization.quant_api._int4_weight_only_transform
163+
# based off torchao.quantization.quant_api._int4_weight_only_transform
163164
b = 4
164165
quantizer = UnifTorchaoQuantizer(
165166
symmetric=False,
@@ -169,20 +170,19 @@ def test_int4_weight_only(self):
169170
eps=1e-6,
170171
preserve_zero=False,
171172
)
172-
self.assertTrue(
173-
quantizer.get_quant_size(b) == quantizer.quant_max - quantizer.quant_min + 1
174-
)
175173
self.compare_quantized_models(model, m_ref, quantizer, b)
176174

177-
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.4+")
175+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
178176
@unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
179177
def test_intx_weight_only(self):
180-
model = M(n=512, k=512).to(_DEVICE)
178+
model = M(m=self.out_dim, n=self.out_dim).to(_DEVICE)
181179
model.reset_parameters()
182180
m_ref = copy.deepcopy(model)
183181

184182
config = IntxWeightOnlyConfig(granularity=PerGroup(self.group_size))
185-
quantize_(m_ref, config, device=_DEVICE)
183+
quantize_(m_ref, config)
184+
185+
# based off torchao.quantization.quant_api._intx_weight_only_transform
186186
b = 8
187187
q_dtype = torch.int8
188188
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[q_dtype]
@@ -195,11 +195,18 @@ def test_intx_weight_only(self):
195195
preserve_zero=True,
196196
zero_point_domain=ZeroPointDomain.INT,
197197
)
198-
self.assertTrue(
199-
quantizer.get_quant_size(b) == max(abs(quant_min), quant_max) + 1
200-
)
201198
self.compare_quantized_models(model, m_ref, quantizer, b)
202199

203200

201+
def load_tests(loader, tests, pattern):
202+
suite = unittest.TestSuite()
203+
suite.addTests(loader.loadTestsFromTestCase(TestPARQuantization))
204+
suite.addTests(loader.loadTestsFromTestCase(TestUnifTorchaoQuantizer))
205+
206+
group_size = suite._tests[-1].out_dim # row-wise grouping
207+
suite.addTest(TestUnifTorchaoQuantizer("test_intx_weight_only", group_size))
208+
return suite
209+
210+
204211
if __name__ == "__main__":
205212
unittest.main()

torchao/prototype/parq/quant/uniform_torchao.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def q_kwargs(self) -> dict[str, Union[int, float]]:
5555
}
5656

5757
def get_quant_size(self, b: int) -> int:
58-
return 2 ** (b - 1) + 1 if self.mapping_type == MappingType.SYMMETRIC else 2**b
58+
return self.quant_max - self.quant_min + 1
5959

6060
def quantize(
6161
self, p: Tensor, b: int, dim: Optional[int] = None
@@ -73,15 +73,13 @@ def quantize(
7373
self.mapping_type,
7474
block_size,
7575
self.target_dtype,
76-
quant_min=self.quant_min,
77-
quant_max=self.quant_max,
7876
eps=self.eps,
7977
preserve_zero=self.preserve_zero,
80-
zero_point_domain=self.zero_point_domain,
78+
**self.q_kwargs,
8179
)
8280
q_args = (block_size, s, zero_point, self.target_dtype)
8381
q = quantize_affine(p, *q_args, **self.q_kwargs)
84-
q = dequantize_affine(q, *q_args, **self.q_kwargs, output_dtype=p.dtype)
82+
q = dequantize_affine(q, *q_args, output_dtype=p.dtype, **self.q_kwargs)
8583

8684
Q = torch.arange(
8785
self.quant_min, self.quant_max + 1, dtype=self.target_dtype, device=p.device

torchao/swizzle/swizzle_ops.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@ def swizzle_mm(aten_op, args, kwargs=None):
3030
a = args[0]
3131
b = args[1]
3232

33-
if torch.is_floating_point(a) and torch.is_floating_point(b) and a.ndim == 2 and b.ndim == 2:
33+
if (
34+
torch.is_floating_point(a)
35+
and torch.is_floating_point(b)
36+
and a.ndim == 2
37+
and b.ndim == 2
38+
):
3439
a_is_swizzled = False
3540
b_is_swizzled = False
3641
if isinstance(a, SwizzleTensor):

0 commit comments

Comments (0)