
Commit 2293732

[MX] Remove mxfp8 kernel and rely on cublas
stack-info: PR: #2130, branch: drisspg/stack/49
1 parent 868afa6 commit 2293732
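
The cuBLAS path this commit standardizes on is torch._scaled_mm with float8_e4m3fn data and blocked E8M0 scales, as exercised in the mx_ops.py hunk below. A minimal standalone sketch (not part of the diff): the _scale_e8m0 reshaping and the M/K/N sizes are illustrative assumptions based on the surrounding code, and the call requires PyTorch 2.8+ plus an SM100-class GPU per the test skips below.

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 128, 256, 512
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Quantize both operands to MX fp8 with a group size of 32, as in the tests below.
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

# cuBLAS expects the E8M0 scales in the blocked ("swizzled") layout.
a_scale_block = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
b_scale_block = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

# MX fp8 gemm through cuBLAS; this is what replaces the removed torchao.ops.mx_fp8_bf16.
out = torch._scaled_mm(
    a_mx._data,
    b_mx._data.t(),  # second operand must be contiguous in K (column-major)
    a_scale_block.view(torch.float8_e8m0fnu),
    b_scale_block.view(torch.float8_e8m0fnu),
    out_dtype=torch.bfloat16,
)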

File tree

7 files changed, +22 -78 lines changed


test/prototype/mx_formats/test_mx_linear.py (+4 -6)

@@ -137,7 +137,6 @@ def test_linear_eager_vs_hp(
     "recipe_name",
     [
         MXLinearRecipeName.MXFP8_CUBLAS,
-        MXLinearRecipeName.MXFP8_CUTLASS,
         MXLinearRecipeName.MXFP4_CUTLASS,
     ],
 )
@@ -206,7 +205,6 @@ def test_activation_checkpointing():
         "mxfp8_emulated",
         "mxfp4_emulated",
         "mxfp8_cublas",
-        "mxfp8_cutlass",
         "mxfp4_cutlass",
     ],
 )
@@ -218,22 +216,22 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
-    if recipe_name in ["mxfp8_emulated", "mxfp8_cutlass"]:
+    if recipe_name in ["mxfp8_emulated"]:
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
-    if recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
+    if recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
         if not TORCH_VERSION_AT_LEAST_2_8:
             pytest.skip("torch.compile requires PyTorch 2.8+")
         if not is_sm_at_least_100():
             pytest.skip("CUDA capability >= 10.0 required for MX gemms")
 
-    if bias and recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
+    if bias and recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
     if use_fp8_dim1_cast_triton_kernel:
-        if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cutlass"):
+        if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
             pytest.skip("unsupported configuration")
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

test/prototype/mx_formats/test_mx_mm.py (+8 -2)

@@ -3,11 +3,13 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+from functools import partial
+
 import pytest
 import torch
 
 from torchao.float8.float8_utils import compute_error
-from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.ops import mx_fp4_bf16
 from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100
@@ -24,7 +26,11 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     b = torch.rand((N, K), dtype=dtype, device=device)
 
     fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
-    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+    mx_func = (
+        partial(torch._scaled_mm, out_dtype=torch.bfloat16)
+        if format == "fp8"
+        else mx_fp4_bf16
+    )
 
     a_mx = MXTensor.to_mx(a, fmt, 32)
     b_mx = MXTensor.to_mx(b, fmt, 32)
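
The partial above is what lets the rest of run_matrix_test stay unchanged: the removed op was called as mx_fp8_bf16(A, B, A_scale, B_scale), and binding out_dtype up front gives torch._scaled_mm the same positional call shape. A small illustration (the bound-callable name is hypothetical, not from the diff):

from functools import partial

import torch

# Removed op call shape: mx_fp8_bf16(A, B, A_scale, B_scale) -> bf16 tensor.
# torch._scaled_mm takes (mat1, mat2, scale_a, scale_b, ...) positionally, so binding
# out_dtype yields a drop-in callable with the same four positional arguments.
mx_fp8_via_cublas = partial(torch._scaled_mm, out_dtype=torch.bfloat16)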

torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu (-30)

@@ -221,33 +221,6 @@ void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale
               "Input tensor 'b' must be contiguous in the K dimension (column-major)");
 }
 
-
-at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
-                       at::Tensor b_scale) {
-#if defined(BUILD_MX_KERNELS_CUTLASS)
-  validate(a, b, a_scale, b_scale);
-  auto M = a.size(0);
-  auto K = a.size(1);
-  auto N = b.size(1);
-
-  auto out =
-      at::empty({M, N}, a.options().dtype(at::kBFloat16));
-  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
-  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
-  using ElementD = cutlass::bfloat16_t;
-
-  using MmaTileShape = Shape<_128,_128,_128>;
-  using ClusterShape = Shape<_2,_1,_1>;
-  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
-
-  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
-  return out;
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
 at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                        at::Tensor b_scale) {
 #if defined(BUILD_MX_KERNELS_CUTLASS)
@@ -278,9 +251,6 @@ at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
 #endif
 }
 
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
-}
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
 }

torchao/ops.py (-23)

@@ -818,29 +818,6 @@ def _check_scale_dtypes(A_scale, B_scale):
     )
 
 
-def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
-    """Defines a matmul between two fp8 tensors w/ MX scales in E8MO and returns a bf16 tensor.
-
-    This op is prototype subject to change.
-
-    Note: The mx scales are E8MO tensors store in uint8 tensors (for now).
-    The layout of the scales is very particular, see:
-    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
-
-    Args:
-        A: fp8 tensor w/ dtype = torch.float8_e4m3fn
-        B: fp8 tensor w/ dtype = torch.float8_e4m3fn
-        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
-        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
-
-    Returns:
-        MXN bf16 Tensor
-
-    """
-    _check_scale_dtypes(A_scale, B_scale)
-    return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)
-
-
 @register_custom_op("torchao::mx_fp8_bf16")
 def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
     """Meta impl for mx_fp8_bf16"""

torchao/prototype/mx_formats/config.py (-3)

@@ -38,7 +38,6 @@ class MXGemmKernelChoice(Enum):
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
     MXFP8_CUBLAS = "mxfp8_cublas"
-    MXFP8_CUTLASS = "mxfp8_cutlass"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"
 
@@ -126,8 +125,6 @@ def from_recipe_name(
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
-        elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS:
-            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS)
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=DTYPE_FP4)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
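
As a quick sanity check of what remains after this hunk (an illustrative snippet, not part of the diff, assuming from_recipe_name is the MXLinearConfig classmethod shown above): the cuBLAS recipe is now the only non-emulated mxfp8 recipe, and it resolves to the cuBLAS kernel choice.

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
)

# MXLinearRecipeName.MXFP8_CUTLASS no longer exists; the surviving mxfp8 recipes
# are MXFP8_EMULATED and MXFP8_CUBLAS, and the latter maps to the cuBLAS gemm.
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS)
assert config.gemm_kernel_choice is MXGemmKernelChoice.CUBLAS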

torchao/prototype/mx_formats/mx_ops.py (+10 -12)

@@ -88,18 +88,16 @@ def mx_mm(aten_op, args, kwargs=None):
     b_scale_block = to_blocked(b_scale)
     if a._elem_dtype == torch.float8_e4m3fn:
         assert b._elem_dtype == torch.float8_e4m3fn
-        if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
-            res = torch._scaled_mm(
-                a._data,
-                b._data,
-                a_scale_block.view(torch.float8_e8m0fnu),
-                b_scale_block.view(torch.float8_e8m0fnu),
-                out_dtype=torch.bfloat16,
-            )
-        else:
-            res = torchao.ops.mx_fp8_bf16(
-                a._data, b._data, a_scale_block, b_scale_block
-            )
+        assert a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS, (
+            "CUBLAS is the only supported kernel choice for MX FP8 operations"
+        )
+        res = torch._scaled_mm(
+            a._data,
+            b._data,
+            a_scale_block.view(torch.float8_e8m0fnu),
+            b_scale_block.view(torch.float8_e8m0fnu),
+            out_dtype=torch.bfloat16,
+        )
     else:
         assert a._elem_dtype == DTYPE_FP4
         assert b._elem_dtype == DTYPE_FP4

torchao/testing/float8/roofline_utils.py (-2)

@@ -180,7 +180,6 @@ def get_tensor_memory_traffic_ovhd_s(
     else:
         assert mx_recipe_name in (
             "mxfp8_emulated",
-            "mxfp8_cutlass",
             "mxfp8_cublas",
         ), "unsupported"
         # For now, assume that we can't profitably fuse kernel 1 and kernel 2
@@ -227,7 +226,6 @@ def get_individual_gemm_time_sympy(
     if mx_recipe_name is not None:
         assert mx_recipe_name in (
             "mxfp8_emulated",
-            "mxfp8_cutlass",
             "mxfp8_cublas",
         ), "unsupported"
     assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), "unsupported"
