
Commit e506cf3
Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig
Summary: att, we will deprecate FbgemmConfig since it targets a single kernel; we'd like to categorize configs by derived dtype + packed format.
Test Plan: python test/quantization/quantize_/test_int4_groupwise_preshuffle.py
Reviewers:
Subscribers:
Tasks:
Tags:
stack-info: PR: #2474, branch: jerryzh168/stack/10
1 parent d87d7db commit e506cf3
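Based on the updated test configs in the diffs below, the intended replacement for the removed FbgemmConfig usage looks roughly like this sketch. It mirrors BF16_ACT_CONFIG and FP8_ACT_CONFIG from the updated preshuffle test; the toy model, device, and availability of the fbgemm-gpu genai kernels on sm90+ are assumptions, not part of the commit.

import torch
from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

# Hypothetical toy model; any module with a bf16 Linear whose in_features is a
# multiple of group_size would do (CUDA + fbgemm-gpu genai kernels assumed).
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
)

# bf16 activation + int4 weight, preshuffled packing (was the bf16 FbgemmConfig)
quantize_(model, Int4WeightOnlyConfig(group_size=128, use_preshuffle=True))

# fp8 activation + int4 weight, preshuffled packing (was the e4m3 FbgemmConfig)
# quantize_(model, Float8ActivationInt4WeightConfig(group_size=128, use_preshuffle=True))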

File tree: 8 files changed, +97 -62 lines changed

test/integration/test_serialization_bc.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@

 _MODEL_NAMES = [
     "torchao-testing/opt-125m-float8dq-row-fbgemm",
+    "torchao-testing/opt-125m-int4wo-preshuffle",
 ]

test/quantization/quantize_/int4/test_int4_groupwise_preshuffle.py

Lines changed: 9 additions & 30 deletions
@@ -17,7 +17,8 @@

 from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -27,36 +28,14 @@
     is_sm_at_least_90,
 )

-BF16_ACT_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 128],
-    preshuffle=True,
+BF16_ACT_CONFIG = Int4WeightOnlyConfig(
+    group_size=128,
+    use_preshuffle=True,
 )

-BF16_ACT_BMM_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 1, 128],
-    preshuffle=True,
-)
-
-FP8_ACT_CONFIG = FbgemmConfig(
-    input_dtype=e4m3_dtype,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 128],
-    preshuffle=True,
-)
-
-FP8_ACT_BMM_CONFIG = FbgemmConfig(
-    input_dtype=e4m3_dtype,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 1, 128],
-    preshuffle=True,
+FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
+    group_size=128,
+    use_preshuffle=True,
 )

@@ -83,7 +62,7 @@ def test_linear(self, config):

     # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
     # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
-    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
     def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):

test/dtypes/test_fbgemm_int4.py renamed to test/quantization/quantize_/int4/test_int4_groupwise_tensor.py

Lines changed: 7 additions & 14 deletions
@@ -13,7 +13,7 @@
 )

 from torchao.quantization import (
-    FbgemmConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -26,19 +26,12 @@
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
-class TestFbgemmInt4Tensor(TestCase):
+class TestInt4GroupwiseTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
+        self.config = Int4WeightOnlyConfig(
+            group_size=128,
+            use_preshuffle=False,
+            gemm_kernel_choice="fbgemm",
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -135,7 +128,7 @@ def forward(self, x):
         original = m(input)
         # we need to transpose the weight first for bmm
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, self.config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)


torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -141,6 +142,7 @@
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
+    "Float8ActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",

torchao/quantization/quant_api.py

Lines changed: 61 additions & 3 deletions
@@ -50,7 +50,6 @@
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_fbgemm_fp8,
-    to_fbgemm_int4,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -73,6 +72,7 @@
 from torchao.quantization.quantize_ import (
     Float8Tensor,
     Int4GroupwisePreshuffleTensor,
+    Int4GroupwiseTensor,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1117,6 +1117,8 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    use_preshuffle: bool = False
+    gemm_kernel_choice: GemmKernelChoice = GemmKernelChoice.ATEN


 # for BC
@@ -1134,15 +1136,38 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    use_preshuffle = config.use_preshuffle
+    gemm_kernel_choice = config.gemm_kernel_choice

     if weight.shape[-1] % group_size != 0:
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
         return weight

+    if use_preshuffle and gemm_kernel_choice != GemmKernelChoice.FBGEMM:
+        raise NotImplementedError(
+            f"use_preshuffle is only supported for fbgemm kernel, got: {gemm_kernel_choice}"
+        )
+
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if gemm_kernel_choice == GemmKernelChoice.FBGEMM:
+        if use_preshuffle:
+            new_weight = Int4GroupwisePreshuffleTensor.from_float(
+                weight,
+                block_size,
+                activation_dtype="bf16",
+            )
+            return new_weight
+        else:
+            new_weight = Int4GroupwiseTensor.from_float(
+                weight,
+                block_size,
+            )
+            return new_weight
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
@@ -1214,6 +1239,39 @@ def _int4_weight_only_transform(
     return module


+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    group_size: int = 128
+    use_preshuffle: bool = False
+    kernel: str = "fbgemm"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _(module: torch.nn.Module, config: Int4WeightOnlyConfig) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying int8 weight only quant requires module to have weight attribute"
+        + " but {module} does not have one"
+    )
+    group_size = config.group_size
+    use_preshuffle = config.use_preshuffle
+    kernel = config.kernel
+
+    assert use_preshuffle, (
+        f"only use_preshuffle == True is supported right now, got: {use_preshuffle}"
+    )
+    assert kernel == "fbgemm", f"only fbgemm kernel is supported, got: {kernel}"
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4GroupwisePreshuffleTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype="fp8",
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
@@ -2078,7 +2136,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
             activation_dtype="bf16",
         )
     else:
-        weight = to_fbgemm_int4(
+        weight = Int4GroupwiseTensor.from_float(
             module.weight,
             config.block_size,
         )
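Because block_size is now derived inside the transform as tuple([1] * (weight.ndim - 1) + [group_size]), a single group_size covers both 2D linear weights and 3D bmm weights, which is why the separate *_BMM_CONFIG entries could be dropped from the test above. A minimal sketch of that derivation (the helper name is illustrative, not from the PR):

import torch

def infer_block_size(weight: torch.Tensor, group_size: int = 128) -> tuple:
    # Every leading dimension gets block size 1; only the last (reduction)
    # dimension is grouped, matching the logic added in quant_api.py above.
    return tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

linear_weight = torch.empty(512, 256)   # [out_features, in_features]
bmm_weight = torch.empty(8, 512, 256)   # [batch, out_features, in_features]

print(infer_block_size(linear_weight))  # (1, 128)
print(infer_block_size(bmm_weight))     # (1, 1, 128)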

torchao/quantization/quantize_/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,9 +3,11 @@
 )
 from .int4 import (
     Int4GroupwisePreshuffleTensor,
+    Int4GroupwiseTensor,
 )

 __all__ = [
     "Int4GroupwisePreshuffleTensor",
+    "Int4GroupwiseTensor",
     "Float8Tensor",
 ]
torchao/quantization/quantize_/int4/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,7 +1,11 @@
 from .int4_groupwise_preshuffle_tensor import (
     Int4GroupwisePreshuffleTensor,
 )
+from .int4_groupwise_tensor import (
+    Int4GroupwiseTensor,
+)

 __all__ = [
     "Int4GroupwisePreshuffleTensor",
+    "Int4GroupwiseTensor",
 ]

torchao/dtypes/fbgemm_int4_tensor.py renamed to torchao/quantization/quantize_/int4/int4_groupwise_tensor.py

Lines changed: 11 additions & 15 deletions
@@ -17,8 +17,7 @@
 )

 __all__ = [
-    "to_fbgemm_int4",
-    "FbgemmInt4Tensor",
+    "Int4GroupwiseTensor",
 ]

 aten = torch.ops.aten
@@ -31,7 +30,7 @@
     pack_int4 = None


-class FbgemmInt4Tensor(TorchAOBaseTensor):
+class Int4GroupwiseTensor(TorchAOBaseTensor):
     tensor_data_attrs = ["packed_weight", "scale", "zero_point"]
     tensor_attributes = ["group_size", "shape"]

@@ -118,7 +117,7 @@ def from_float(
         zero_point = zero_point.to(w.dtype)

         del w
-        return FbgemmInt4Tensor(
+        return Int4GroupwiseTensor(
             packed_weight=wq,
             scale=scale,
             zero_point=zero_point,
@@ -127,7 +126,7 @@
         )


-implements = FbgemmInt4Tensor.implements
+implements = Int4GroupwiseTensor.implements


 @implements([torch.nn.functional.linear, aten.linear.default])
@@ -143,8 +142,8 @@ def _(func, types, args, kwargs):
     res = torch.ops.fbgemm.bf16i4bf16_rowwise(
         input_tensor,
         weight_tensor.packed_weight.contiguous(),
-        weight_tensor.scale,
-        weight_tensor.zero_point,
+        weight_tensor.scale.contiguous(),
+        weight_tensor.zero_point.contiguous(),
     )
     res = res.reshape(*orig_act_size[:-1], orig_out_features)
     if bias is not None:
@@ -185,10 +184,10 @@ def _(func, types, args, kwargs):
     )


-def _same_metadata(self: "FbgemmInt4Tensor", src: "FbgemmInt4Tensor") -> bool:
+def _same_metadata(self: "Int4GroupwiseTensor", src: "Int4GroupwiseTensor") -> bool:
     return (
-        isinstance(self, FbgemmInt4Tensor)
-        and isinstance(src, FbgemmInt4Tensor)
+        isinstance(self, Int4GroupwiseTensor)
+        and isinstance(src, Int4GroupwiseTensor)
         and self.shape == src.shape
         and self.packed_weight.shape == src.packed_weight.shape
         and self.scale.shape == src.scale.shape
@@ -287,9 +286,6 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)


-to_fbgemm_int4 = FbgemmInt4Tensor.from_float
-
-
 if TORCH_VERSION_AT_LEAST_2_5:
-    # Allow a model with FbgemmInt4Tensor weights to be loaded with `weights_only=True`
-    torch.serialization.add_safe_globals([FbgemmInt4Tensor])
+    # Allow a model with Int4GroupwiseTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([Int4GroupwiseTensor])
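The add_safe_globals registration above is what lets checkpoints containing Int4GroupwiseTensor weights load with weights_only=True (the behavior exercised by test_serialization_bc.py). A hedged sketch of that round trip, reusing the config from the renamed test's setUp; the toy model, file path, and environment are assumptions, not part of the commit:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
)
# Same config as the renamed test's setUp (non-preshuffled fbgemm int4 groupwise weights).
quantize_(
    model,
    Int4WeightOnlyConfig(group_size=128, use_preshuffle=False, gemm_kernel_choice="fbgemm"),
)

torch.save(model.state_dict(), "int4_groupwise_checkpoint.pt")
# weights_only=True works because Int4GroupwiseTensor was registered via add_safe_globals.
state_dict = torch.load("int4_groupwise_checkpoint.pt", weights_only=True)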
