from typing import Any, Dict, Optional

import torch
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
from torch.utils._pytree import tree_map

import torchao.ops
@@ -52,21 +55,28 @@ def decorator(func):
    return decorator


-@implements([aten.detach.default])
-def mx_desugar_op(aten_op, args, kwargs=None):
-    old = args[0]
-    new_data = aten_op(old._data, *args[1:], **kwargs)
-    new = MXTensor(
-        old._scale_e8m0,
-        new_data,
-        old._elem_dtype,
-        old._block_size,
-        old._orig_dtype,
-        old._use_fp4_custom_triton_dequant_kernel,
-        old._gemm_kernel_choice,
-        old._pack_fp6,
+# @implements([aten.detach.default])
+# def mx_desugar_op(aten_op, args, kwargs=None):
+#     old = args[0]
+#     new_data = aten_op(old._data, *args[1:], **kwargs)
+#     new = MXTensor(
+#         old._scale_e8m0,
+#         new_data,
+#         old._elem_dtype,
+#         old._block_size,
+#         old._orig_dtype,
+#         old._use_fp4_custom_triton_dequant_kernel,
+#         old._gemm_kernel_choice,
+#         old._pack_fp6,
+#     )
+#     return new
+
+
+@implements([aten.detach.default, aten.alias.default])
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(func)
    )
-    return new


def _get_gemm_choice(
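The replacement handler registered above follows the `(func, types, args, kwargs)` signature that `__torch_dispatch__` supplies, and corrects view aliasing by passing `args[0]._apply_fn_to_data(func)` through `return_and_correct_aliasing`. Below is a minimal sketch of the same pattern on a toy wrapper subclass, assuming an `_apply_fn_to_data` helper that applies the op to the inner data and re-wraps; `ToyWrapper` and its `_inner` field are invented for illustration and are not torchao's MXTensor.

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten


class ToyWrapper(torch.Tensor):
    # Toy wrapper subclass, invented for illustration only (not torchao's MXTensor).
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            inner.shape,
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner: torch.Tensor):
        self._inner = inner

    def __tensor_flatten__(self):
        return ["_inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        return ToyWrapper(inner_tensors["_inner"])

    def _apply_fn_to_data(self, fn):
        # apply fn (e.g. aten.detach.default) to the inner tensor and re-wrap
        return ToyWrapper(fn(self._inner))

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in (aten.detach.default, aten.alias.default):
            # same shape as the MXTensor handler added in this diff
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(func)
            )
        raise NotImplementedError(f"{func} is not handled in this sketch")


t = ToyWrapper(torch.randn(4))
d = t.detach()
print(type(d), d.shape)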
@@ -85,12 +95,15 @@ def _get_gemm_choice(
    return choice_a if choice_a is not None else choice_b


-@implements([aten.mm.default, aten.matmul.default])
-def mx_mm(aten_op, args, kwargs=None):
-    a = args[0]
-    b = args[1]
-    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
+def _addmm_mx_dispatch(
+    a: MXTensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Core implementation shared between mx_mm and mx_addmm.
+    The only difference is whether bias is None or not.
+    """
    gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice)
+
    if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
        # real MX gemm backed by torchao's CUTLASS kernels
        M, K, N = a.shape[0], a.shape[1], b.shape[1]
@@ -112,28 +125,63 @@ def mx_mm(aten_op, args, kwargs=None):
                b._data,
                a_scale_block.view(torch.float8_e8m0fnu),
                b_scale_block.view(torch.float8_e8m0fnu),
+                bias=bias,
                out_dtype=torch.bfloat16,
            )
        else:
            assert a._elem_dtype == DTYPE_FP4
            assert b._elem_dtype == DTYPE_FP4
            assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
+            # FP4 operations
            res = torchao.ops.mx_fp4_bf16(
                a._data, b._data, a_scale_block, b_scale_block
            )
+
+            # TODO update kernel to accept optional
+            if bias is not None:
+                res = res + bias
    else:
        # emulated MX gemm
        a_hp = a.to_dtype(a._orig_dtype)
        b_hp = b.to_dtype(b._orig_dtype)
        # assert memory layout we expect to be required in hardware
        assert a_hp.is_contiguous()
        assert b_hp.t().is_contiguous()
-        res = aten_op(a_hp, b_hp)
+
+        # Call appropriate aten_op based on whether bias is provided
+        if bias is not None:
+            res = aten_op(bias, a_hp, b_hp)  # addmm
+        else:
+            res = aten_op(a_hp, b_hp)  # mm
+
    return res


+@implements([aten.mm.default, aten.matmul.default])
+def mx_mm(func, types, args, kwargs):
+    a = args[0]
+    b = args[1]
+    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
+
+    return _addmm_mx_dispatch(a, b, func)
+
+
+@implements([aten.addmm.default])
+def mx_addmm(func, types, args, kwargs):
+    assert (
+        isinstance(args[0], torch.Tensor)
+        and isinstance(args[1], MXTensor)
+        and isinstance(args[2], MXTensor)
+    )
+    bias = args[0]
+    a = args[1]
+    b = args[2]
+
+    return _addmm_mx_dispatch(a, b, func, bias=bias)
+
+
@implements([aten.t.default])
-def mx_t(aten_op, args, kwargs=None):
+def mx_t(func, types, args, kwargs):
    # For now, only transpose(input, 0, 1) is supported.
    old = args[0]
    new = MXTensor(
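The refactor above splits the old `mx_mm` into a shared `_addmm_mx_dispatch` core plus thin `mx_mm` and `mx_addmm` wrappers, because a biased `F.linear` reaches the dispatcher as `aten.addmm` rather than `aten.mm`. A small, self-contained way to observe that decomposition is sketched below; the `LogAtenOps` mode is illustrative only and not part of this PR.

import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode


class LogAtenOps(TorchDispatchMode):
    # print every aten op that reaches __torch_dispatch__, then run it
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))


x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)

with LogAtenOps():
    F.linear(x, w)     # no bias: expect t + mm, i.e. the mx_mm path
    F.linear(x, w, b)  # with bias: expect addmm, i.e. the new mx_addmm path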
@@ -150,7 +198,7 @@ def mx_t(aten_op, args, kwargs=None):


@implements([aten.sum.dim_IntList])
-def mx_cast_up_op(aten_op, args, kwargs=None):
+def mx_cast_up_op(func, types, args, kwargs):
    """Be careful with this function, this is a "fallback" op that
    casts the output of the op to the original precision. And performs the op.

@@ -166,11 +214,11 @@ def unwrap(x):

    new_args = tree_map(unwrap, args)
    new_kwargs = tree_map(unwrap, kwargs)
-    return aten_op(*new_args, **new_kwargs)
+    return func(*new_args, **new_kwargs)


@implements([aten.view.default])
-def mx_view_op(aten_op, args, kwargs=None):
+def mx_view_op(func, types, args, kwargs):
    data = args[0]._data
    new_size = args[1]
    if args[0]._elem_dtype == DTYPE_FP4:
@@ -179,7 +227,7 @@ def mx_view_op(aten_op, args, kwargs=None):
    elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
        # special case fp6 as we pack 4 elements in 3 bytes
        new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
-    new_data = aten_op(data, new_size, *args[2:], **kwargs)
+    new_data = func(data, new_size, *args[2:], **kwargs)
    return MXTensor(
        args[0]._scale_e8m0,
        new_data,
@@ -193,7 +241,7 @@ def mx_view_op(aten_op, args, kwargs=None):


@implements([aten._to_copy.default])
-def autocast_to_copy(aten_op, args, kwargs=None):
+def autocast_to_copy(func, types, args, kwargs):
    """This gets called when running matmul under autocast
    when the input is a MXTensor, presenting as a fp32
    tensor.
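Throughout this diff the handlers drop the old `(aten_op, args, kwargs=None)` signature in favor of the `(func, types, args, kwargs)` tuple that `__torch_dispatch__` passes, so dispatch arguments can be forwarded to registered handlers unchanged. A rough sketch of the registration pattern this implies follows; the `_TOY_OPS_TABLE` name and the toy handler are assumptions for illustration, not the file's actual `implements` implementation.

import torch

aten = torch.ops.aten

# hypothetical registry for illustration; the file's real table may differ
_TOY_OPS_TABLE = {}


def implements(aten_ops):
    # register one handler for several aten overloads
    def decorator(fn):
        for op in aten_ops:
            _TOY_OPS_TABLE[op] = fn
        return fn

    return decorator


@implements([aten.add.Tensor])
def _toy_add(func, types, args, kwargs):
    # handlers receive exactly what __torch_dispatch__ passes, so a subclass
    # can forward with: _TOY_OPS_TABLE[func](func, types, args, kwargs or {})
    return func(*args, **(kwargs or {}))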