Commit cc359e6

Add support for float8 activation for Int4GroupwisePreshuffleTensor

Summary:
Added basic op support (linear and bmm). Both the float8 and bf16 activation variants live in the same Tensor subclass, since the weight dtype is the same; the only difference is whether the activation is quantized. The quantization parameters do differ between the two implementations:

bf16 activation:
* group_scale
* group_zero

fp8 activation:
* group_scale
* row_scale

Test Plan:
python test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2437, branch: jerryzh168/stack/4
1 parent 5971b02 commit cc359e6
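
To make the scale-layout difference in the summary concrete, here is a small illustrative sketch (not the actual Int4GroupwisePreshuffleTensor internals): it assumes an [N, K] bfloat16 weight with group size 128, and all names, shapes, and dtypes below are assumptions for illustration only.

import torch

# Illustrative only: dict names, shapes, and dtypes are assumptions,
# not the real tensor subclass fields.
N, K, group_size = 256, 128, 128

# bf16 activation path: int4 weight carries a per-group scale and zero point.
bf16_act_params = {
    "group_scale": torch.ones(N, K // group_size, dtype=torch.bfloat16),
    "group_zero": torch.zeros(N, K // group_size, dtype=torch.bfloat16),
}

# fp8 activation path: the activation is quantized to float8, so the weight
# carries a per-group scale plus a per-row scale instead of a group zero point.
fp8_act_params = {
    "group_scale": torch.ones(N, K // group_size, dtype=torch.bfloat16),
    "row_scale": torch.ones(N, dtype=torch.bfloat16),
}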

4 files changed, +225 -76 lines changed


test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Lines changed: 67 additions & 24 deletions
@@ -4,11 +4,14 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import tempfile
 import unittest

 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )

@@ -23,6 +26,42 @@
     is_sm_at_least_90,
 )

+BF16_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="bf16",
+)
+
+BF16_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="bf16",
+)
+
+FP8_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="fp8",
+)
+
+FP8_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="fp8",
+)
+

 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -32,33 +71,23 @@
 )
 class TestInt4GroupwisePreshuffleTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)

-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -74,32 +103,46 @@ def forward(self, x):
         m = M(weight).eval()
         original = m(input)
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)

-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)

             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)

             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)

-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
         )

+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
+

 if __name__ == "__main__":
     run_tests()
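
For readers unfamiliar with the parametrization helpers the test now uses, here is a minimal standalone sketch of the pattern: parametrize attaches a list of values to a test method and instantiate_parametrized_tests expands it into one generated test per value. ExampleTest and its parameter values are made up for illustration and are not part of torchao.

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleTest(TestCase):
    @parametrize("activation_dtype", ["bf16", "fp8"])
    def test_roundtrip(self, activation_dtype):
        # One generated test per parameter value, so a failure points at the
        # exact configuration that broke.
        x = torch.randn(4, 4)
        self.assertEqual(x.dtype, torch.float32)


instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    run_tests()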

torchao/dtypes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,9 @@
 )
 from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
 from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
+from .float8_activation_int4_groupwise_preshuffle_tensor import (
+    Float8ActivationInt4GroupwisePreshuffleTensor,
+)
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -70,4 +73,5 @@
     "FbgemmFp8Tensor",
     "Int8DynamicActInt4WeightCPULayout",
     "Int4GroupwisePreshuffleTensor",
+    "Float8ActivationInt4GroupwisePreshuffleTensor",
 ]

torchao/quantization/quant_api.py

Lines changed: 6 additions & 1 deletion
@@ -2040,6 +2040,8 @@ class FbgemmConfig(AOBaseConfig):
         weight_dtype (torch.dtype): weight dtype of the kernel
         output_dtype (torch.dtype): output dtype of the kernel
         group_size (int): The group size for weight
+        preshuffle (bool): whether preshuffle the weights or not
+        activation_dtype_for_int4 (str): the dtype for activation for int4 weight, either bf16 or fp8
     """

     input_dtype: torch.dtype
@@ -2048,6 +2050,7 @@
     block_size: Optional[List[int]] = None
     activation_scale_ub: Optional[float] = None
     preshuffle: bool = False
+    activation_dtype_for_int4: str = "bf16"


 @register_quantize_module_handler(FbgemmConfig)
@@ -2067,7 +2070,9 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
     ):
         if config.preshuffle:
             weight = Int4GroupwisePreshuffleTensor.from_float(
-                module.weight, config.block_size
+                module.weight,
+                config.block_size,
+                activation_dtype=config.activation_dtype_for_int4,
             )
         else:
             weight = to_fbgemm_int4(
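
Putting the config change together with the tests above, a minimal end-user sketch follows; it assumes FbgemmConfig and quantize_ are importable from torchao.quantization as the test file does, and mirrors FP8_ACT_CONFIG from the test.

import torch
from torchao.quantization import FbgemmConfig, quantize_

# Same shape as FP8_ACT_CONFIG in the test file; activation_dtype_for_int4 is
# the new field added in this commit and defaults to "bf16".
fp8_act_config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
    activation_dtype_for_int4="fp8",
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, fp8_act_config)  # weight becomes Int4GroupwisePreshuffleTensor
out = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))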
