Commit 9ff26dd
Add support for float8 activation for Int4GroupwisePreshuffleTensor
Summary: Added basic op support (linear and bmm) for float8 activation. Both the float8 and bf16 activation paths live in the same tensor subclass, since the weight dtype is the same; the only difference is whether the activation is quantized. The two paths do differ in which scales the weight carries:

bf16 activation:
* group_scale
* group_zero

fp8 activation:
* group_scale
* row_scale

Test Plan: python test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py

Reviewers: Subscribers: Tasks: Tags:

stack-info: PR: #2437, branch: jerryzh168/stack/4
1 parent ad1efd7 commit 9ff26dd
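For illustration, here is a minimal usage sketch of the new fp8-activation path, assembled from the test configs and quantize_ calls shown in the diffs below. The import location of FbgemmConfig and quantize_ (torchao.quantization) is assumed rather than shown in this commit; e4m3_dtype comes from torchao.float8.config as in the test file.

    import torch
    from torchao.float8.config import e4m3_dtype
    from torchao.quantization import FbgemmConfig, quantize_  # assumed import path

    # fp8 activation + int4 preshuffled weight, groupwise with group size 128
    fp8_act_config = FbgemmConfig(
        input_dtype=e4m3_dtype,
        weight_dtype=torch.int4,
        output_dtype=torch.bfloat16,
        block_size=[1, 128],
        preshuffle=True,
    )

    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, fp8_act_config)  # weight becomes Int4GroupwisePreshuffleTensor
    out = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))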

File tree: 6 files changed, +289 −77 lines
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+_MODEL_NAMES = [
+    "torchao-testing/opt-125m-int4wo-preshuffle",
+]
+
+
+class TestSerializationBC(TestCase):
+    """Test that we can still load and run models serialized with previous AO versions;
+    we commit to BC for 3 pytorch releases.
+    """
+
+    @common_utils.parametrize("model_name", _MODEL_NAMES)
+    def test_load_model_and_run(self, model_name):
+        # Load the serialized, already-quantized model
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype="bfloat16",
+            device_map="cuda",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        prompt = ("Hello, my name is",)
+
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+        ).to("cuda")
+        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
+        # make sure generation and decoding run end to end
+        _ = tokenizer.batch_decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+
+common_utils.instantiate_parametrized_tests(TestSerializationBC)
+
+if __name__ == "__main__":
+    run_tests()
test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Lines changed: 64 additions & 24 deletions
@@ -4,11 +4,14 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import tempfile
 import unittest
 
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
@@ -22,6 +25,39 @@
     _is_fbgemm_genai_gpu_available,
     is_sm_at_least_90,
 )
+from torchao.float8.config import e4m3_dtype
+
+BF16_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+)
+
+BF16_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+)
+
+FP8_ACT_CONFIG = FbgemmConfig(
+    input_dtype=e4m3_dtype,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+)
+
+FP8_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=e4m3_dtype,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+)
 
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -32,33 +68,23 @@
 )
 class TestInt4GroupwisePreshuffleTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -74,32 +100,46 @@ def forward(self, x):
         m = M(weight).eval()
         original = m(input)
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
         )
 
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
+
 
 if __name__ == "__main__":
     run_tests()

torchao/quantization/quant_api.py

Lines changed: 19 additions & 1 deletion
@@ -2040,6 +2040,8 @@ class FbgemmConfig(AOBaseConfig):
         weight_dtype (torch.dtype): weight dtype of the kernel
         output_dtype (torch.dtype): output dtype of the kernel
         group_size (int): The group size for weight
+        preshuffle (bool): whether to preshuffle the weights or not
+        activation_dtype_for_int4 (str): the activation dtype to use with int4 weight, either bf16 or fp8
     """
 
     input_dtype: torch.dtype
@@ -2067,7 +2069,9 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
     ):
         if config.preshuffle:
             weight = Int4GroupwisePreshuffleTensor.from_float(
-                module.weight, config.block_size
+                module.weight,
+                config.block_size,
+                activation_dtype="bf16",
             )
         else:
             weight = to_fbgemm_int4(
@@ -2077,6 +2081,20 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
         return module
+    if (
+        (config.input_dtype == e4m3_dtype)
+        and (config.weight_dtype == torch.int4)
+        and (config.output_dtype == torch.bfloat16)
+    ):
+        if config.preshuffle:
+            weight = Int4GroupwisePreshuffleTensor.from_float(
+                module.weight,
+                config.block_size,
+                activation_dtype="fp8",
+            )
+            module.weight = torch.nn.Parameter(weight, requires_grad=False)
+            module.extra_repr = types.MethodType(_linear_extra_repr, module)
+            return module
     elif (
         (config.input_dtype == e4m3_dtype)
         and (config.weight_dtype == e4m3_dtype)
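To make the scale layout from the summary concrete, here is a hedged sketch of calling the new path directly; the from_float signature and block_size follow the diff above, while the import path and the CUDA bf16 weight are assumptions:

    import torch
    from torchao.quantization import Int4GroupwisePreshuffleTensor  # assumed import path

    w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")

    # bf16 activation path: weight carries group_scale and group_zero (per the summary)
    w_bf16_act = Int4GroupwisePreshuffleTensor.from_float(w, [1, 128], activation_dtype="bf16")

    # fp8 activation path: weight carries group_scale and row_scale instead
    w_fp8_act = Int4GroupwisePreshuffleTensor.from_float(w, [1, 128], activation_dtype="fp8")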

torchao/quantization/quantize_/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from .int4_groupwise_preshuffle_tensor import (
+from .int4 import (
     Int4GroupwisePreshuffleTensor,
 )
 
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from .int4_groupwise_preshuffle_tensor import (
+    Int4GroupwisePreshuffleTensor,
+)
+
+__all__ = [
+    "Int4GroupwisePreshuffleTensor",
+]
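The re-export above keeps the public class location stable after the move into the int4/ subpackage: test_module_path still asserts 'torchao.quantization.Int4GroupwisePreshuffleTensor'. A small sketch of the equivalent import paths (the top-level path is inferred from that assertion, the quantize_ path from this __init__.py):

    from torchao.quantization import Int4GroupwisePreshuffleTensor
    from torchao.quantization.quantize_ import Int4GroupwisePreshuffleTensor as _SameClass

    assert Int4GroupwisePreshuffleTensor is _SameClass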
