@@ -221,33 +221,6 @@ void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale
               "Input tensor 'b' must be contiguous in the K dimension (column-major)");
 }
 
-at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
-                       at::Tensor b_scale) {
-#if defined(BUILD_MX_KERNELS_CUTLASS)
-  validate(a, b, a_scale, b_scale);
-  auto M = a.size(0);
-  auto K = a.size(1);
-  auto N = b.size(1);
-
-  auto out =
-      at::empty({M, N}, a.options().dtype(at::kBFloat16));
-  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
-  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
-  using ElementD = cutlass::bfloat16_t;
-
-  using MmaTileShape = Shape<_128,_128,_128>;
-  using ClusterShape = Shape<_2,_1,_1>;
-  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
-
-  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
-  return out;
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
 at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                        at::Tensor b_scale) {
 #if defined(BUILD_MX_KERNELS_CUTLASS)
@@ -278,9 +251,6 @@ at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
 #endif
 }
 
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
-}
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
 }
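
Note (not part of the diff): after this change only torchao::mx_fp4_bf16 remains registered for CUDA. As a minimal sketch, a C++ caller could reach the surviving op through the PyTorch dispatcher as below; call_mx_fp4_bf16 is a hypothetical helper name, and this assumes the torchao extension library is already loaded.

#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical caller sketch: resolve the schema registered under the
// torchao namespace and dispatch to the CUDA kernel bound via
// TORCH_LIBRARY_IMPL(torchao, CUDA, m) above.
at::Tensor call_mx_fp4_bf16(at::Tensor a, at::Tensor b,
                            at::Tensor a_scale, at::Tensor b_scale) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchao::mx_fp4_bf16", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, at::Tensor, at::Tensor)>();
  return op.call(a, b, a_scale, b_scale);
}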