Commit ff4682e
Add fbgemm preshuffle kernel into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig
Summary: att, we will deprecate FbgemmConfig since it's a single kernel; we'd like to categorize things by derived dtype + packed format.

Test Plan: python test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2474, branch: jerryzh168/stack/10
1 parent 7e9f224 commit ff4682e
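The new configs replace FbgemmConfig's per-dtype arguments with the derived-dtype + packed-format split described above. A minimal usage sketch for the bf16-activation path, assuming the API added in this commit (the model and shapes are illustrative, not part of the commit):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# illustrative bf16 linear layer; any module whose weight's last dim is
# divisible by group_size would work
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
)

# int4 weight, bf16 activation, fbgemm preshuffle packing (added in this commit)
quantize_(model, Int4WeightOnlyConfig(group_size=128, use_preshuffle=True))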

File tree

3 files changed: +56 −35 lines

test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Lines changed: 9 additions & 34 deletions
@@ -16,7 +16,8 @@
 )

 from torchao.quantization import (
-    FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -26,40 +27,14 @@
     is_sm_at_least_90,
 )

-BF16_ACT_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 128],
-    preshuffle=True,
-    activation_dtype_for_int4="bf16",
+BF16_ACT_CONFIG = Int4WeightOnlyConfig(
+    group_size=128,
+    use_preshuffle=True,
 )

-BF16_ACT_BMM_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 1, 128],
-    preshuffle=True,
-    activation_dtype_for_int4="bf16",
-)
-
-FP8_ACT_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 128],
-    preshuffle=True,
-    activation_dtype_for_int4="fp8",
-)
-
-FP8_ACT_BMM_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 1, 128],
-    preshuffle=True,
-    activation_dtype_for_int4="fp8",
+FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
+    group_size=128,
+    use_preshuffle=True,
 )


@@ -86,7 +61,7 @@ def test_linear(self, config):

     # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
     # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
-    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
     def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
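The separate *_BMM_CONFIG variants could be dropped because, in the quant_api.py change below, block_size is derived from group_size at quantize time, so 2D linear weights and 3D bmm weights share one config. A small illustration of that derivation (variable names are ours, not from the commit):

import torch

group_size = 128
weight_2d = torch.empty(512, 256)     # linear weight
weight_3d = torch.empty(8, 512, 256)  # bmm weight

# same derivation as _int4_weight_only_quantize_tensor in this commit
block_size_2d = tuple([1] * (weight_2d.ndim - 1) + [group_size])  # (1, 128)
block_size_3d = tuple([1] * (weight_3d.ndim - 1) + [group_size])  # (1, 1, 128)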

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -141,6 +142,7 @@
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
+    "Float8ActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",

torchao/quantization/quant_api.py

Lines changed: 45 additions & 1 deletion
@@ -1115,6 +1115,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    use_preshuffle: bool = False


 # for BC
@@ -1132,15 +1133,25 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    use_preshuffle = config.use_preshuffle

     if weight.shape[-1] % group_size != 0:
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
         return weight

+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if use_preshuffle:
+        new_weight = Int4GroupwisePreshuffleTensor.from_float(
+            weight,
+            block_size,
+            activation_dtype="bf16",
+        )
+        return new_weight
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
@@ -1212,6 +1223,39 @@ def _int4_weight_only_transform(
     return module


+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    group_size: int = 128
+    use_preshuffle: bool = False
+    kernel: str = "fbgemm"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _(module: torch.nn.Module, config: Int4WeightOnlyConfig) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying int8 weight only quant requires module to have weight attribute"
+        + " but {module} does not have one"
+    )
+    group_size = config.group_size
+    use_preshuffle = config.use_preshuffle
+    kernel = config.kernel
+
+    assert use_preshuffle, (
+        f"only use_preshuffle == True is supported right now, got: {use_preshuffle}"
+    )
+    assert kernel == "fbgemm", f"only fbgemm kernel is supported, got: {kernel}"
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4GroupwisePreshuffleTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype="fp8",
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
