Skip to content

[CPU] enable int8_dynamic_activation_int4_weight with Int4CPULayout #2128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,31 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
def test_8da4w_cpu(self, dtype, x_dim):
device = "cpu"
m = ToyLinearModel().eval().to(dtype).to(device)
example_inputs = m.example_inputs(dtype=dtype, device=device)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)

with torch.no_grad():
quantize_(
m,
int8_dynamic_activation_int4_weight(
group_size=32, layout=Int4CPULayout()
),
)
# ensure the expected op is in the code
_, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

# TODO(#1690): move to new config names
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down
5 changes: 3 additions & 2 deletions torchao/dtypes/uintx/int4_cpu_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ def from_plain(
assert isinstance(_layout, Int4CPULayout)

if TORCH_VERSION_AT_LEAST_2_6:
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
assert int_data.dtype in [torch.int32, torch.int8], (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` or `int8` dtype"
)
int_data = int_data.to(torch.int32)
packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
int_data,
1, # TODO:remove
Expand Down
Loading