Skip to content

Commit 0596713

Browse files
authored
Fix float8 + int4 QAT (#2851)
**Summary:** After #2779, `Float8DynamicActivationInt4Weight` no longer has the `group_size` field, but QAT continues to read from this field. This commit fixes it to just use the fixed 128 group size. **Test Plan:** ``` python test/quantization/test_qat.py -k test_infer_fp8_int4_config ```
1 parent 253d65a commit 0596713

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

test/quantization/test_qat.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,26 @@ def test_quantize_api_fp8_int4(self):
19321932
target_convert_sqnr=float("inf"),
19331933
)
19341934

1935+
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
1936+
def test_infer_fp8_int4_config(self):
1937+
"""
1938+
Test that fake quantize configs are correctly inferred from
1939+
`Float8DynamicActivationInt4WeightConfig`.
1940+
"""
1941+
from torchao.quantization.qat.fake_quantize_config import (
1942+
_infer_fake_quantize_configs,
1943+
)
1944+
1945+
base_config = Float8DynamicActivationInt4WeightConfig()
1946+
(act_config, weight_config) = _infer_fake_quantize_configs(base_config)
1947+
self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
1948+
self.assertEqual(act_config.dtype, torch.float8_e4m3fn)
1949+
self.assertIsInstance(act_config.granularity, PerRow)
1950+
self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
1951+
self.assertEqual(weight_config.dtype, torch.int4)
1952+
self.assertEqual(weight_config.group_size, 128)
1953+
self.assertTrue(weight_config.is_symmetric)
1954+
19351955

19361956
instantiate_parametrized_tests(TestQAT)
19371957

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def _infer_fake_quantize_configs(
382382
)
383383
weight_config = IntxFakeQuantizeConfig(
384384
dtype=torch.int4,
385-
group_size=base_config.group_size,
385+
group_size=128,
386386
is_symmetric=True,
387387
)
388388
else:

0 commit comments

Comments
 (0)