
Commit f5977ce

Add fbgemm preshuffle kernel into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig
Summary: att, we will deprecate FbgemmConfig since it's a single kernel. We'd like to categorize things by derived dtype + packed format.

Test Plan: python test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Reviewers:
Subscribers:
Tasks:
Tags:

stack-info: PR: #2474, branch: jerryzh168/stack/10
1 parent 897ec7e commit f5977ce
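
For orientation, a minimal usage sketch of the two config paths this commit introduces, applied through quantize_. The toy model, shapes, and device are illustrative assumptions; the preshuffle path also assumes a CUDA device with the fbgemm int4 preshuffle kernels available (the test below gates on SM90).

import torch
from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

# bf16 activation + preshuffled int4 weight, now expressed via Int4WeightOnlyConfig
model_bf16 = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
quantize_(model_bf16, Int4WeightOnlyConfig(group_size=128, use_preshuffle=True))

# fp8 activation + preshuffled int4 weight, via the new dedicated config
model_fp8 = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
quantize_(model_fp8, Float8ActivationInt4WeightConfig(group_size=128, use_preshuffle=True))

In both cases the Linear weights end up as Int4GroupwisePreshuffleTensor instances, as shown in the quant_api.py changes below.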

File tree

3 files changed: +56 −35 lines changed

  test/quantization/quantize_/test_int4_groupwise_preshuffle.py
  torchao/quantization/__init__.py
  torchao/quantization/quant_api.py

test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Lines changed: 9 additions & 34 deletions
@@ -16,7 +16,8 @@
 )
 
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -26,40 +27,14 @@
     is_sm_at_least_90,
 )
 
-BF16_ACT_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 128],
-    preshuffle=True,
-    activation_dtype_for_int4="bf16",
+BF16_ACT_CONFIG = Int4WeightOnlyConfig(
+    group_size=128,
+    use_preshuffle=True,
 )
 
-BF16_ACT_BMM_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 1, 128],
-    preshuffle=True,
-    activation_dtype_for_int4="bf16",
-)
-
-FP8_ACT_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 128],
-    preshuffle=True,
-    activation_dtype_for_int4="fp8",
-)
-
-FP8_ACT_BMM_CONFIG = FbgemmConfig(
-    input_dtype=torch.bfloat16,
-    weight_dtype=torch.int4,
-    output_dtype=torch.bfloat16,
-    block_size=[1, 1, 128],
-    preshuffle=True,
-    activation_dtype_for_int4="fp8",
+FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
+    group_size=128,
+    use_preshuffle=True,
 )
 
 
@@ -86,7 +61,7 @@ def test_linear(self, config):
 
     # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
     # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
-    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
     def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
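
The separate *_BMM_CONFIG constants are dropped because block_size is now derived from the weight's ndim at quantize time (see the quant_api.py hunk below), so one config covers both the 2D linear weight and the 3D bmm weight. A small sketch of that derivation, with illustrative shapes:

import torch

group_size = 128
linear_weight = torch.empty(256, 512)    # 2D weight of an nn.Linear
bmm_weight = torch.empty(10, 256, 512)   # 3D weight used by torch.bmm

# mirrors the expression added in quant_api.py
block_size_2d = tuple([1 for _ in range(linear_weight.ndim - 1)] + [group_size])
block_size_3d = tuple([1 for _ in range(bmm_weight.ndim - 1)] + [group_size])
print(block_size_2d)  # (1, 128)
print(block_size_3d)  # (1, 1, 128)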

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -141,6 +142,7 @@
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
+    "Float8ActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",

torchao/quantization/quant_api.py

Lines changed: 45 additions & 1 deletion
@@ -1114,6 +1114,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    use_preshuffle: bool = False
 
 
 # for BC
@@ -1131,15 +1132,25 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    use_preshuffle = config.use_preshuffle
 
     if weight.shape[-1] % group_size != 0:
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
         return weight
 
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if use_preshuffle:
+        new_weight = Int4GroupwisePreshuffleTensor.from_float(
+            weight,
+            block_size,
+            activation_dtype="bf16",
+        )
+        return new_weight
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
@@ -1211,6 +1222,39 @@ def _int4_weight_only_transform(
     return module
 
 
+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    group_size: int = 128
+    use_preshuffle: bool = False
+    kernel: str = "fbgemm"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _(module: torch.nn.Module, config: Int4WeightOnlyConfig) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying int8 weight only quant requires module to have weight attribute"
+        + " but {module} does not have one"
+    )
+    group_size = config.group_size
+    use_preshuffle = config.use_preshuffle
+    kernel = config.kernel
+
+    assert use_preshuffle, (
+        f"only use_preshuffle == True is supported right now, got: {use_preshuffle}"
+    )
+    assert kernel == "fbgemm", f"only fbgemm kernel is supported, got: {kernel}"
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4GroupwisePreshuffleTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype="fp8",
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
