Skip to content

Commit 7d76707

Browse files
committed
Unskip test_qat_8da4w_prepare_vs_convert
**Test Plan:** python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert
1 parent cdced21 commit 7d76707

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

test/quantization/test_qat.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1474,7 +1474,6 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    @unittest.skip("Currently failing on sqnr")
     def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1493,7 +1492,11 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         torch.manual_seed(seed)
         x = m.example_inputs()

-        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        quantizer = Int8DynActInt4WeightQATQuantizer(
+            groupsize=group_size,
+            precision=dtype,
+            scales_precision=dtype,
+        )
         prepared = quantizer.prepare(m)
         prepared_out = prepared(*x)
         converted = quantizer.convert(prepared)

torchao/quantization/GPTQ.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,11 @@ def linear_forward_8da4w(
     groupsize,
     precision,
 ):
-    x = per_token_dynamic_quant(x, scale_dtype=precision, zero_point_dtype=precision)
+    x = per_token_dynamic_quant(
+        x,
+        scale_dtype=torch.float32,
+        zero_point_dtype=torch.int8,
+    )
     # TODO: verify and remove following reshape code
     # origin_x_size = x.size()
     # x = x.reshape(-1, origin_x_size[-1])

torchao/quantization/qat/linear.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def __init__(
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
-        activation_config = _get_8da4w_activation_config(scales_precision)
+        activation_config = _get_8da4w_activation_config(torch.float32)
         weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,

0 commit comments

Comments (0)