
Commit 93067f0

add subclass based method for inference
1 parent 2293732 commit 93067f0

4 files changed: +172 -6 lines changed
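
In short: the commit adds MXFPConfig, a quantize_ config that converts a bf16 nn.Linear to MX FP8 inference by wrapping its weight in an MXTensor subclass and quantizing activations at call time. A minimal usage sketch, mirroring the new test_inference_subclass added below (assumes CUDA, an sm100 GPU, and PyTorch 2.8+):

    import torch
    import torch.nn as nn
    from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
    from torchao.quantization import quantize_

    # bf16 linear on CUDA, as required by the new transform
    m = nn.Sequential(nn.Linear(32, 128, bias=False, dtype=torch.bfloat16)).cuda()

    # swap the weight for an MX-quantized subclass; defaults are fp8 e4m3 + CUBLAS
    quantize_(m, config=MXFPConfig())

    # the subclass dispatches the matmul to the MX kernels under torch.compile
    m_mx = torch.compile(m, fullgraph=True)
    y = m_mx(torch.randn(128, 32, device="cuda", dtype=torch.bfloat16))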

test/prototype/mx_formats/test_mx_linear.py (+32)

@@ -25,6 +25,7 @@
     MXInferenceLinear,
     MXLinear,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
@@ -372,3 +373,34 @@ def test_inference_print_str():
     s = str(m)
     assert "bl_sz=32" in s
     assert "kernel=emulated" in s
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(not is_sm_at_least_100(), reason="Reqs sm100")
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+def test_inference_subclass(elem_dtype):
+    """
+    Smoke test for inference compile
+    """
+    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        if not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+
+    m = nn.Sequential(nn.Linear(32, 128, bias=False, dtype=torch.bfloat16))
+    m = m.cuda()
+    m_mx = copy.deepcopy(m)
+    config = MXFPConfig()
+    quantize_(m_mx, config=config)
+    m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    if elem_dtype is torch.float8_e4m3fn:
+        assert sqnr >= 20.0
+    else:
+        assert sqnr >= 11.5

torchao/__init__.py (+2 -1)

@@ -43,7 +43,7 @@
     quantize_,
 )
 
-from . import dtypes, optim, swizzle, testing
+from . import dtypes, optim, quantization, swizzle, testing
 
 __all__ = [
     "dtypes",
@@ -53,4 +53,5 @@
     "swizzle",
     "testing",
     "ops",
+    "quantization",
 ]

torchao/prototype/mx_formats/mx_ops.py (+21 -5)

@@ -17,7 +17,7 @@
 the underlying data fields to the MX matmul.
 """
 
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import torch
 from torch.utils._pytree import tree_map
@@ -69,13 +69,29 @@ def mx_desugar_op(aten_op, args, kwargs=None):
     return new
 
 
+def _get_gemm_choice(
+    choice_a: Optional[MXGemmKernelChoice], choice_b: Optional[MXGemmKernelChoice]
+) -> MXGemmKernelChoice:
+    if choice_a is not None and choice_b is not None:
+        assert choice_a == choice_b, (
+            "Both MXTensor inputs must have the same gemm config if specified"
+        )
+        return choice_a
+
+    # Assert that at least one is set and return that one
+    assert choice_a is not None or choice_b is not None, (
+        "At least one gemm choice must be specified"
+    )
+    return choice_a if choice_a is not None else choice_b
+
+
 @implements([aten.mm.default, aten.matmul.default])
 def mx_mm(aten_op, args, kwargs=None):
     a = args[0]
     b = args[1]
     assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
-    assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported"
-    if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
+    gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice)
+    if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
         assert a._data.is_contiguous()
@@ -88,7 +104,7 @@ def mx_mm(aten_op, args, kwargs=None):
         b_scale_block = to_blocked(b_scale)
         if a._elem_dtype == torch.float8_e4m3fn:
             assert b._elem_dtype == torch.float8_e4m3fn
-            assert a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS, (
+            assert gemm_choice is MXGemmKernelChoice.CUBLAS, (
                 "CUBLAS is the only supported kernel choice for MX FP8 operations"
             )
             res = torch._scaled_mm(
@@ -101,7 +117,7 @@ def mx_mm(aten_op, args, kwargs=None):
         else:
             assert a._elem_dtype == DTYPE_FP4
             assert b._elem_dtype == DTYPE_FP4
-            assert a._gemm_kernel_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
+            assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
             res = torchao.ops.mx_fp4_bf16(
                 a._data, b._data, a_scale_block, b_scale_block
             )
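
Why the Optional handling: with the new subclass-based inference path (mx_subclass.py below), the activation MXTensor is created with gemm_kernel_choice=None and inherits the kernel choice from the weight, so mx_mm can no longer require both inputs to carry the same non-None value. A small illustration of the resolution rule, using only names from this diff:

    from torchao.prototype.mx_formats import MXGemmKernelChoice
    from torchao.prototype.mx_formats.mx_ops import _get_gemm_choice

    # one side None (the activation), the other set (the weight): the set one wins
    assert _get_gemm_choice(None, MXGemmKernelChoice.CUBLAS) is MXGemmKernelChoice.CUBLAS

    # both set: they must match, and that common choice is returned
    assert _get_gemm_choice(MXGemmKernelChoice.CUTLASS, MXGemmKernelChoice.CUTLASS) is MXGemmKernelChoice.CUTLASS

    # both None, or mismatched values, trip the asserts inside the helper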
torchao/prototype/mx_formats/mx_subclass.py (new file, +117)

@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import types
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+import torchao
+from torchao.core.config import AOBaseConfig
+from torchao.prototype.mx_formats import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.config import (
+    _validate_elem_dtype,
+    _validate_gemm_kernel_choice,
+)
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.quantization.quant_api import to_linear_activation_quantized
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+from torchao.utils import is_sm_at_least_100
+
+
+@dataclass
+class MXFPConfig(AOBaseConfig):
+    block_size: int = 32
+
+    # Dtypes for Input and Weights
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+
+    # Which kernel to run for mm
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
+
+    # Set recommended inductor perf settings
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        assert self.activation_dtype == self.weight_dtype, (
+            "For now - we only support matching input/weight dtypes."
+        )
+        _validate_elem_dtype(self.activation_dtype)
+        _validate_elem_dtype(self.weight_dtype)
+        _validate_gemm_kernel_choice(
+            self.gemm_kernel_choice, self.block_size, self.weight_dtype
+        )
+
+
+def _linear_extra_repr(self):
+    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}"
+
+
+def _input_activation_quant_func_mxfp(
+    x: torch.Tensor,
+    activation_dtype: torch.dtype,
+    block_size: int,
+    scale: Optional[torch.Tensor] = None,
+):
+    """Quantize the input activation to an MXTensor."""
+
+    # TODO scale for static quant
+
+    activation = MXTensor.to_mx(
+        x,
+        activation_dtype,
+        block_size=block_size,
+        gemm_kernel_choice=None,  # Get from weight
+        pack_fp6=False,  # TODO
+    )
+    return activation
+
+
+@register_quantize_module_handler(MXFPConfig)
+def _mx_inference_linear_transform(module: torch.nn.Module, config: MXFPConfig):
+    # TODO Sm120 has slightly more restrictive reqs
+    # TODO handle AMD
+    assert is_sm_at_least_100(), "MXFP is only supported on sm100 machines for now"
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    weight = module.weight
+
+    assert weight.dtype == torch.bfloat16, (
+        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
+    )
+
+    # Convert weight to MX Tensor
+    quantized_weight = MXTensor.to_mx(
+        weight,
+        weight_dtype,
+        block_size=config.block_size,
+        gemm_kernel_choice=config.gemm_kernel_choice,
+        pack_fp6=False,  # TODO
+    )
+
+    input_quant_func = _input_activation_quant_func_mxfp
+    input_quant_kwargs = {
+        "block_size": config.block_size,
+        "activation_dtype": activation_dtype,
+        "scale": None,
+    }
+
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
