Skip to content

Commit 7947ba1

Browse files
committed
various fixes
1 parent 34fa252 commit 7947ba1

File tree

6 files changed

+108
-32
lines changed

6 files changed

+108
-32
lines changed

test/prototype/mx_formats/test_mx_linear.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -381,15 +381,16 @@ def test_inference_print_str():
381381
)
382382
@pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
383383
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
384-
def test_inference_subclass(elem_dtype):
384+
@pytest.mark.parametrize("bias", [True, False])
385+
def test_inference_subclass(elem_dtype, bias: bool):
385386
"""
386387
Smoke test for inference compile
387388
"""
388389
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
389390
if not is_sm_at_least_89():
390391
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
391392

392-
m = nn.Sequential(nn.Linear(32, 128, bias=False, dtype=torch.bfloat16))
393+
m = nn.Sequential(nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16))
393394
m = m.cuda()
394395
m_mx = copy.deepcopy(m)
395396
config = MXFPConfig()
@@ -401,6 +402,8 @@ def test_inference_subclass(elem_dtype):
401402
y_mx = m_mx(x)
402403
sqnr = compute_error(y_ref, y_mx)
403404
if elem_dtype is torch.float8_e4m3fn:
404-
assert sqnr >= 20.0
405+
assert (
406+
sqnr >= 20.0 if not bias else 14
407+
) # TODO see if this is working as expected
405408
else:
406409
assert sqnr >= 11.5

torchao/core/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
175175
"torchao.quantization",
176176
"torchao.sparsity.sparse_api",
177177
"torchao.prototype.quantization",
178+
"torchao.prototype.mx_formats",
178179
}
179180

180181

torchao/prototype/mx_formats/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
MXLinearConfig,
55
MXLinearRecipeName,
66
)
7+
from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
78

89
# import mx_linear here to register the quantize_ transform logic
910
# ruff: noqa: I001
@@ -14,4 +15,5 @@
1415
"MXInferenceLinearConfig",
1516
"MXLinearConfig",
1617
"MXLinearRecipeName",
18+
"MXFPConfig",
1719
]

torchao/prototype/mx_formats/mx_ops.py

+74-26
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from typing import Any, Dict, Optional
2121

2222
import torch
23+
from torch.utils._python_dispatch import (
24+
return_and_correct_aliasing,
25+
)
2326
from torch.utils._pytree import tree_map
2427

2528
import torchao.ops
@@ -52,21 +55,28 @@ def decorator(func):
5255
return decorator
5356

5457

55-
@implements([aten.detach.default])
56-
def mx_desugar_op(aten_op, args, kwargs=None):
57-
old = args[0]
58-
new_data = aten_op(old._data, *args[1:], **kwargs)
59-
new = MXTensor(
60-
old._scale_e8m0,
61-
new_data,
62-
old._elem_dtype,
63-
old._block_size,
64-
old._orig_dtype,
65-
old._use_fp4_custom_triton_dequant_kernel,
66-
old._gemm_kernel_choice,
67-
old._pack_fp6,
58+
# @implements([aten.detach.default])
59+
# def mx_desugar_op(aten_op, args, kwargs=None):
60+
# old = args[0]
61+
# new_data = aten_op(old._data, *args[1:], **kwargs)
62+
# new = MXTensor(
63+
# old._scale_e8m0,
64+
# new_data,
65+
# old._elem_dtype,
66+
# old._block_size,
67+
# old._orig_dtype,
68+
# old._use_fp4_custom_triton_dequant_kernel,
69+
# old._gemm_kernel_choice,
70+
# old._pack_fp6,
71+
# )
72+
# return new
73+
74+
75+
@implements([aten.detach.default, aten.alias.default])
76+
def _(func, types, args, kwargs):
77+
return return_and_correct_aliasing(
78+
func, args, kwargs, args[0]._apply_fn_to_data(func)
6879
)
69-
return new
7080

7181

7282
def _get_gemm_choice(
@@ -85,12 +95,15 @@ def _get_gemm_choice(
8595
return choice_a if choice_a is not None else choice_b
8696

8797

88-
@implements([aten.mm.default, aten.matmul.default])
89-
def mx_mm(aten_op, args, kwargs=None):
90-
a = args[0]
91-
b = args[1]
92-
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
98+
def _addmm_mx_dispatch(
99+
a: MXTensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None
100+
) -> torch.Tensor:
101+
"""
102+
Core implementation shared between mx_mm and mx_addmm.
103+
The only difference is whether bias is None or not.
104+
"""
93105
gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice)
106+
94107
if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
95108
# real MX gemm backed by torchao's CUTLASS kernels
96109
M, K, N = a.shape[0], a.shape[1], b.shape[1]
@@ -112,28 +125,63 @@ def mx_mm(aten_op, args, kwargs=None):
112125
b._data,
113126
a_scale_block.view(torch.float8_e8m0fnu),
114127
b_scale_block.view(torch.float8_e8m0fnu),
128+
bias=bias,
115129
out_dtype=torch.bfloat16,
116130
)
117131
else:
118132
assert a._elem_dtype == DTYPE_FP4
119133
assert b._elem_dtype == DTYPE_FP4
120134
assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
135+
# FP4 operations
121136
res = torchao.ops.mx_fp4_bf16(
122137
a._data, b._data, a_scale_block, b_scale_block
123138
)
139+
140+
# TODO update kernel to accept optional
141+
if bias is not None:
142+
res = res + bias
124143
else:
125144
# emulated MX gemm
126145
a_hp = a.to_dtype(a._orig_dtype)
127146
b_hp = b.to_dtype(b._orig_dtype)
128147
# assert memory layout we expect to be required in hardware
129148
assert a_hp.is_contiguous()
130149
assert b_hp.t().is_contiguous()
131-
res = aten_op(a_hp, b_hp)
150+
151+
# Call appropriate aten_op based on whether bias is provided
152+
if bias is not None:
153+
res = aten_op(bias, a_hp, b_hp) # addmm
154+
else:
155+
res = aten_op(a_hp, b_hp) # mm
156+
132157
return res
133158

134159

160+
@implements([aten.mm.default, aten.matmul.default])
161+
def mx_mm(func, types, args, kwargs):
162+
a = args[0]
163+
b = args[1]
164+
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
165+
166+
return _addmm_mx_dispatch(a, b, func)
167+
168+
169+
@implements([aten.addmm.default])
170+
def mx_addmm(func, types, args, kwargs):
171+
assert (
172+
isinstance(args[0], torch.Tensor)
173+
and isinstance(args[1], MXTensor)
174+
and isinstance(args[2], MXTensor)
175+
)
176+
bias = args[0]
177+
a = args[1]
178+
b = args[2]
179+
180+
return _addmm_mx_dispatch(a, b, func, bias=bias)
181+
182+
135183
@implements([aten.t.default])
136-
def mx_t(aten_op, args, kwargs=None):
184+
def mx_t(func, types, args, kwargs):
137185
# For now, only transpose(input, 0, 1) is supported.
138186
old = args[0]
139187
new = MXTensor(
@@ -150,7 +198,7 @@ def mx_t(aten_op, args, kwargs=None):
150198

151199

152200
@implements([aten.sum.dim_IntList])
153-
def mx_cast_up_op(aten_op, args, kwargs=None):
201+
def mx_cast_up_op(func, types, args, kwargs):
154202
"""Be careful with this function, this is a "fallback" op that
155203
casts the output of the op to the original precision. And performs the op.
156204
@@ -166,11 +214,11 @@ def unwrap(x):
166214

167215
new_args = tree_map(unwrap, args)
168216
new_kwargs = tree_map(unwrap, kwargs)
169-
return aten_op(*new_args, **new_kwargs)
217+
return func(*new_args, **new_kwargs)
170218

171219

172220
@implements([aten.view.default])
173-
def mx_view_op(aten_op, args, kwargs=None):
221+
def mx_view_op(func, types, args, kwargs):
174222
data = args[0]._data
175223
new_size = args[1]
176224
if args[0]._elem_dtype == DTYPE_FP4:
@@ -179,7 +227,7 @@ def mx_view_op(aten_op, args, kwargs=None):
179227
elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
180228
# special case fp6 as we pack 4 elements in 3 bytes
181229
new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
182-
new_data = aten_op(data, new_size, *args[2:], **kwargs)
230+
new_data = func(data, new_size, *args[2:], **kwargs)
183231
return MXTensor(
184232
args[0]._scale_e8m0,
185233
new_data,
@@ -193,7 +241,7 @@ def mx_view_op(aten_op, args, kwargs=None):
193241

194242

195243
@implements([aten._to_copy.default])
196-
def autocast_to_copy(aten_op, args, kwargs=None):
244+
def autocast_to_copy(func, types, args, kwargs):
197245
"""This gets called when running matmul under autocast
198246
when the input is a MXTensor, presenting as a fp32
199247
tensor.

torchao/prototype/mx_formats/mx_subclass.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from torchao.quantization.transform_module import (
2525
register_quantize_module_handler,
2626
)
27-
from torchao.utils import is_sm_at_least_100
27+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
2828

2929

3030
@dataclass
@@ -115,3 +115,9 @@ def _mx_inference_linear_transform(module: torch.nn.Module, config: MXFPConfig):
115115
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
116116
module.extra_repr = types.MethodType(_linear_extra_repr, module)
117117
return module
118+
119+
120+
if TORCH_VERSION_AT_LEAST_2_5:
121+
torch.serialization.add_safe_globals(
122+
[MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
123+
)

torchao/prototype/mx_formats/mx_tensor.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
from enum import Enum, auto
21-
from typing import Dict, Union
21+
from typing import Callable, Dict, Union
2222

2323
import torch
2424

@@ -562,7 +562,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
562562
from torchao.prototype.mx_formats.mx_ops import MX_OPS_TABLE
563563

564564
if func in MX_OPS_TABLE:
565-
return MX_OPS_TABLE[func](func, args, kwargs)
565+
return MX_OPS_TABLE[func](func, types, args, kwargs)
566566

567567
raise NotImplementedError(f"{func} not implemented")
568568

@@ -631,5 +631,21 @@ def __tensor_unflatten__(
631631
metadata["_pack_fp6"],
632632
)
633633

634+
def _apply_fn_to_data(self, fn: Callable):
635+
"""Applies a fn to all tensor components stored on this class"""
636+
tensor_names, ctx = self.__tensor_flatten__()
637+
638+
# Apply the function to each tensor component
639+
new_tensors = {}
640+
for name in tensor_names:
641+
new_tensors[name] = fn(getattr(self, name))
642+
643+
return self.__class__.__tensor_unflatten__(
644+
new_tensors,
645+
ctx,
646+
None, # outer_size parameter
647+
None, # outer_stride parameter
648+
)
649+
634650
# Do not force the MXTensor type on the returned tensor
635651
__torch_function__ = torch._C._disabled_torch_function_impl

0 commit comments

Comments
 (0)