Commit c13fa2b
Add support for float8 activation for Int4GroupwisePreshuffleTensor
Summary: Added basic op support (linear and bmm). Both the float8-activation and bf16-activation variants live in the same Tensor subclass, since the weight dtype is identical; the only difference is whether the activation is quantized. There are some differences in the constituent tensors, however:

bf16 activation:
* group_scale
* group_zero

fp8 activation:
* group_scale
* row_scale

Test Plan:
python test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2437, branch: jerryzh168/stack/4
1 parent b828ffc commit c13fa2b
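
A minimal usage sketch of the fp8-activation path described in the summary, assuming the FbgemmConfig fields exercised in the updated tests below (preshuffle, activation_dtype_for_int4); this is an illustration, not code taken from the commit:

import torch
from torchao.quantization import FbgemmConfig, quantize_

# Assumed config, mirroring FP8_ACT_CONFIG in the updated test file below.
fp8_act_config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
    activation_dtype_for_int4="fp8",  # "bf16" would keep the activation unquantized
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, fp8_act_config)

# With fp8 activation the quantized weight carries group_scale and row_scale;
# with bf16 activation it carries group_scale and group_zero instead.
print(type(linear.weight).__name__, hasattr(linear.weight, "row_scale"))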

File tree

6 files changed (+836, -74 lines)
Lines changed: 166 additions & 0 deletions (new file)

@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.quantization import (
    FbgemmConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    is_sm_at_least_90,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4GroupwisePreshuffleTensor(TestCase):
    def setUp(self):
        self.config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 128],
            preshuffle=True,
            float8_activation=True,
        )
        self.bmm_config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 1, 128],
            preshuffle=True,
            float8_activation=True,
        )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    def test_linear(self):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, self.config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

    # @unittest.skip("WIP: this doesn't work yet")
    def test_slice(self):
        dtype = torch.bfloat16
        device = "cuda"
        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
        dummy1.weight = torch.nn.Parameter(
            dummy.weight.narrow(0, 0, 64), requires_grad=False
        )
        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        dummy2.weight = torch.nn.Parameter(
            dummy.weight.narrow(1, 0, 128), requires_grad=False
        )

        quantize_(dummy, self.config)
        weight1 = dummy.weight.narrow(0, 0, 64)
        weight2 = dummy.weight.narrow(1, 0, 128)
        # check that the slicing operation is performed correctly on the constituent Tensors
        self.assertEqual(
            weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
        )
        self.assertEqual(weight1.group_scale, dummy.weight.group_scale.narrow(2, 0, 64))
        self.assertEqual(
            weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
        )
        self.assertEqual(weight2.group_scale, dummy.weight.group_scale.narrow(0, 0, 1))

        # check that 1. the sliced bf16 weight and 2. the sliced quantized weight
        # produce similar results for matmul on the same input Tensor

        input = torch.randn(2, 256, dtype=dtype, device=device)
        res_ref = dummy1(input)
        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
        res = dummy(input)
        sqnr = compute_error(res, res_ref)
        assert sqnr > 20, f"Got: {sqnr}"

        input = torch.randn(2, 128, dtype=dtype, device=device)
        res_ref = dummy2(input)
        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
        res = dummy(input)
        sqnr = compute_error(res, res_ref)
        assert sqnr > 15, f"Got: {sqnr}"

    def test_slice_and_copy_(self):
        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
        )
        quantize_(l, self.config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
        assert (
            param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
        )
        assert param.data.group_scale.data_ptr() == param_data.group_scale.data_ptr()
        assert param.data.row_scale.data_ptr() == param_data.row_scale.data_ptr()
        orig_value = param.data.packed_weight[0][0].item()

        # dummy_l has a random weight (shouldn't be 0)
        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        quantize_(dummy_l, self.config)
        quantized = dummy_l.weight
        quantized = quantized.narrow(0, 0, 512)

        param_data.copy_(quantized)

        # making sure param.data is updated
        assert param.data.packed_weight[0][0] != orig_value

    def test_bmm(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return torch.bmm(x, self.weight)

        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
        m = M(weight).eval()
        original = m(input)
        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

    def test_to_device(self):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device)


if __name__ == "__main__":
    run_tests()

test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Lines changed: 67 additions & 24 deletions
@@ -4,11 +4,14 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import tempfile
 import unittest

 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )

@@ -23,6 +26,42 @@
     is_sm_at_least_90,
 )

+BF16_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="bf16",
+)
+
+BF16_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="bf16",
+)
+
+FP8_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="fp8",
+)
+
+FP8_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="fp8",
+)
+

 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -32,33 +71,23 @@
 )
 class TestInt4GroupwisePreshuffleTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)

-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -74,32 +103,46 @@ def forward(self, x):
         m = M(weight).eval()
         original = m(input)
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)

-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)

             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)

             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)

-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
         )

+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
+

 if __name__ == "__main__":
     run_tests()

torchao/dtypes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,9 @@
 )
 from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
 from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
+from .float8_activation_int4_groupwise_preshuffle_tensor import (
+    Float8ActivationInt4GroupwisePreshuffleTensor,
+)
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -70,4 +73,5 @@
     "FbgemmFp8Tensor",
     "Int8DynamicActInt4WeightCPULayout",
     "Int4GroupwisePreshuffleTensor",
+    "Float8ActivationInt4GroupwisePreshuffleTensor",
 ]
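
With the export above, the new subclass can be imported directly from torchao.dtypes; a small sketch assuming only the names added in this diff (the helper function is hypothetical):

# Assumes the import and __all__ entry added above are in place.
from torchao.dtypes import Float8ActivationInt4GroupwisePreshuffleTensor

def is_fp8_act_int4_weight(t: object) -> bool:
    # Hypothetical convenience check for quantized weights after quantize_().
    return isinstance(t, Float8ActivationInt4GroupwisePreshuffleTensor)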
