[MX] Remove mxfp8 kernel and rely on cublas #2130

Merged · 1 commit · May 2, 2025
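Summary: this PR removes the CUTLASS-based mxfp8 kernel (the torchao.ops.mx_fp8_bf16 op, its CUDA implementation, and the MXFP8_CUTLASS recipe) and routes MXFP8 matmuls through cuBLAS via torch._scaled_mm; the MXFP4 CUTLASS path is untouched. Below is a minimal sketch of the cuBLAS path the diff switches to, not a definitive recipe: it assumes the MXTensor internals used by these tests (_data and _scale_e8m0), illustrative 128x128x128 shapes, and hardware/software that clears the test skips further down (PyTorch 2.8+ and CUDA capability 10.0).

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 128, 128, 128
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Quantize both operands to MXFP8 (e4m3 elements, E8M0 scales, group size 32).
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

# cuBLAS expects the E8M0 scales in the blocked (swizzled) layout.
a_scale_block = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
b_scale_block = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

# The cuBLAS-backed MX gemm that replaces the removed torchao.ops.mx_fp8_bf16;
# B is passed transposed so it is column-major in K, as torch._scaled_mm requires.
out = torch._scaled_mm(
    a_mx._data,
    b_mx._data.t(),
    a_scale_block.view(torch.float8_e8m0fnu),
    b_scale_block.view(torch.float8_e8m0fnu),
    out_dtype=torch.bfloat16,
)
assert out.shape == (M, N) and out.dtype == torch.bfloat16

This mirrors the new branch in torchao/prototype/mx_formats/mx_ops.py below, where the same torch._scaled_mm call replaces the old CUBLAS/CUTLASS dispatch.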
10 changes: 4 additions & 6 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -137,7 +137,6 @@ def test_linear_eager_vs_hp(
"recipe_name",
[
MXLinearRecipeName.MXFP8_CUBLAS,
- MXLinearRecipeName.MXFP8_CUTLASS,
MXLinearRecipeName.MXFP4_CUTLASS,
],
)
@@ -206,7 +205,6 @@ def test_activation_checkpointing():
"mxfp8_emulated",
"mxfp4_emulated",
"mxfp8_cublas",
"mxfp8_cutlass",
"mxfp4_cutlass",
],
)
@@ -218,22 +216,22 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
"""
Verify that compile does not change numerics of MX linear fw + bw
"""
if recipe_name in ["mxfp8_emulated", "mxfp8_cutlass"]:
if recipe_name in ["mxfp8_emulated"]:
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

if recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
if recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
if not TORCH_VERSION_AT_LEAST_2_8:
pytest.skip("torch.compile requires PyTorch 2.8+")
if not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for MX gemms")

- if bias and recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
+ if bias and recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
# TODO(future PR): fix this, things are clearly broken with bias=True
pytest.skip("this test is broken for non-emulated recipes with bias=True")

if use_fp8_dim1_cast_triton_kernel:
if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cutlass"):
if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
pytest.skip("unsupported configuration")
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
10 changes: 8 additions & 2 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -3,11 +3,13 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
+ from functools import partial

import pytest
import torch

from torchao.float8.float8_utils import compute_error
- from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+ from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100
@@ -24,7 +26,11 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
b = torch.rand((N, K), dtype=dtype, device=device)

fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
- mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+ mx_func = (
+     partial(torch._scaled_mm, out_dtype=torch.bfloat16)
+     if format == "fp8"
+     else mx_fp4_bf16
+ )

a_mx = MXTensor.to_mx(a, fmt, 32)
b_mx = MXTensor.to_mx(b, fmt, 32)
30 changes: 0 additions & 30 deletions torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu
@@ -221,33 +221,6 @@ void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale
"Input tensor 'b' must be contiguous in the K dimension (column-major)");
}


- at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
- at::Tensor b_scale) {
- #if defined(BUILD_MX_KERNELS_CUTLASS)
- validate(a, b, a_scale, b_scale);
- auto M = a.size(0);
- auto K = a.size(1);
- auto N = b.size(1);
-
- auto out =
- at::empty({M, N}, a.options().dtype(at::kBFloat16));
- using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
- using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
- using ElementD = cutlass::bfloat16_t;
-
- using MmaTileShape = Shape<_128,_128,_128>;
- using ClusterShape = Shape<_2,_1,_1>;
- using PerSmTileShape_MNK = Shape<_128,_128,_128>;
-
- run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
- return out;
- #else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
- #endif
- }

at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
@@ -278,9 +251,6 @@ at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
#endif
}

- TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
- m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
- }
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
}
23 changes: 0 additions & 23 deletions torchao/ops.py
@@ -818,29 +818,6 @@ def _check_scale_dtypes(A_scale, B_scale):
)


- def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
- """Defines a matmul between two fp8 tensors w/ MX scales in E8MO and returns a bf16 tensor.
-
- This op is prototype subject to change.
-
- Note: The mx scales are E8MO tensors store in uint8 tensors (for now).
- The layout of the scales is very particular, see:
- https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
-
- Args:
- A: fp8 tensor w/ dtype = torch.float8_e4m3fn
- B: fp8 tensor w/ dtype = torch.float8_e4m3fn
- A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
- B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
-
- Returns:
- MXN bf16 Tensor
-
- """
- _check_scale_dtypes(A_scale, B_scale)
- return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)


@register_custom_op("torchao::mx_fp8_bf16")
def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
"""Meta impl for mx_fp8_bf16"""
3 changes: 0 additions & 3 deletions torchao/prototype/mx_formats/config.py
@@ -38,7 +38,6 @@ class MXGemmKernelChoice(Enum):
class MXLinearRecipeName(Enum):
MXFP8_EMULATED = "mxfp8_emulated"
MXFP8_CUBLAS = "mxfp8_cublas"
MXFP8_CUTLASS = "mxfp8_cutlass"
MXFP4_EMULATED = "mxfp4_emulated"
MXFP4_CUTLASS = "mxfp4_cutlass"

@@ -126,8 +125,6 @@ def from_recipe_name(
return MXLinearConfig()
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
- elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS:
- return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS)
elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
return MXLinearConfig(elem_dtype=DTYPE_FP4)
elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
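A usage note for the config change above: a short sketch of how a recipe now resolves to the cuBLAS kernel choice once MXFP8_CUTLASS is gone. It assumes from_recipe_name is the MXLinearConfig classmethod edited in this hunk and that the resulting config exposes the gemm_kernel_choice field passed to its constructor.

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
)

# "mxfp8_cublas" is now the only non-emulated MXFP8 recipe.
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS)
assert config.gemm_kernel_choice is MXGemmKernelChoice.CUBLAS

# MXLinearRecipeName.MXFP8_CUTLASS no longer exists, so referencing it
# raises AttributeError rather than silently selecting a removed kernel.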
22 changes: 10 additions & 12 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -88,18 +88,16 @@ def mx_mm(aten_op, args, kwargs=None):
b_scale_block = to_blocked(b_scale)
if a._elem_dtype == torch.float8_e4m3fn:
assert b._elem_dtype == torch.float8_e4m3fn
- if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
- res = torch._scaled_mm(
- a._data,
- b._data,
- a_scale_block.view(torch.float8_e8m0fnu),
- b_scale_block.view(torch.float8_e8m0fnu),
- out_dtype=torch.bfloat16,
- )
- else:
- res = torchao.ops.mx_fp8_bf16(
- a._data, b._data, a_scale_block, b_scale_block
- )
+ assert a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS, (
+ "CUBLAS is the only supported kernel choice for MX FP8 operations"
+ )
+ res = torch._scaled_mm(
+ a._data,
+ b._data,
+ a_scale_block.view(torch.float8_e8m0fnu),
+ b_scale_block.view(torch.float8_e8m0fnu),
+ out_dtype=torch.bfloat16,
+ )
else:
assert a._elem_dtype == DTYPE_FP4
assert b._elem_dtype == DTYPE_FP4
2 changes: 0 additions & 2 deletions torchao/testing/float8/roofline_utils.py
@@ -180,7 +180,6 @@ def get_tensor_memory_traffic_ovhd_s(
else:
assert mx_recipe_name in (
"mxfp8_emulated",
"mxfp8_cutlass",
"mxfp8_cublas",
), "unsupported"
# For now, assume that we can't profitably fuse kernel 1 and kernel 2
@@ -227,7 +226,6 @@ def get_individual_gemm_time_sympy(
if mx_recipe_name is not None:
assert mx_recipe_name in (
"mxfp8_emulated",
"mxfp8_cutlass",
"mxfp8_cublas",
), "unsupported"
assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), "unsupported"