diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 86631e58b2..f35b6bdfcc 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -875,6 +875,31 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    def test_8da4w_cpu(self, dtype, x_dim):
+        device = "cpu"
+        m = ToyLinearModel().eval().to(dtype).to(device)
+        example_inputs = m.example_inputs(dtype=dtype, device=device)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+
+        with torch.no_grad():
+            quantize_(
+                m,
+                int8_dynamic_activation_int4_weight(
+                    group_size=32, layout=Int4CPULayout()
+                ),
+            )
+            # ensure the expected op is in the code
+            _, code = torch._inductor.utils.run_and_get_code(
+                torch.compile(m, fullgraph=True, dynamic=True),
+                *example_inputs,
+            )
+            assert "_weight_int4pack_mm_for_cpu" in code[0]
+            assert "aten.mm.default" not in code[0]
+
     # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py
index 9c368fd17a..605e71ed6a 100644
--- a/torchao/dtypes/uintx/int4_cpu_layout.py
+++ b/torchao/dtypes/uintx/int4_cpu_layout.py
@@ -112,9 +112,10 @@ def from_plain(
         assert isinstance(_layout, Int4CPULayout)
 
         if TORCH_VERSION_AT_LEAST_2_6:
-            assert int_data.dtype == torch.int32, (
-                "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
+            assert int_data.dtype in [torch.int32, torch.int8], (
+                "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` or `int8` dtype"
             )
+            int_data = int_data.to(torch.int32)
             packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
                 int_data,
                 1,  # TODO:remove