diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index 934e40eb74..b0cee1e918 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -137,7 +137,6 @@ def test_linear_eager_vs_hp(
     "recipe_name",
     [
         MXLinearRecipeName.MXFP8_CUBLAS,
-        MXLinearRecipeName.MXFP8_CUTLASS,
         MXLinearRecipeName.MXFP4_CUTLASS,
     ],
 )
@@ -206,7 +205,6 @@ def test_activation_checkpointing():
         "mxfp8_emulated",
         "mxfp4_emulated",
         "mxfp8_cublas",
-        "mxfp8_cutlass",
         "mxfp4_cutlass",
     ],
 )
@@ -218,22 +216,22 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
-    if recipe_name in ["mxfp8_emulated", "mxfp8_cutlass"]:
+    if recipe_name in ["mxfp8_emulated"]:
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
-    if recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
+    if recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
         if not TORCH_VERSION_AT_LEAST_2_8:
             pytest.skip("torch.compile requires PyTorch 2.8+")
         if not is_sm_at_least_100():
             pytest.skip("CUDA capability >= 10.0 required for MX gemms")
 
-    if bias and recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
+    if bias and recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
     if use_fp8_dim1_cast_triton_kernel:
-        if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cutlass"):
+        if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
             pytest.skip("unsupported configuration")
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py
index 2f9695aa38..1b16fa24ab 100644
--- a/test/prototype/mx_formats/test_mx_mm.py
+++ b/test/prototype/mx_formats/test_mx_mm.py
@@ -3,11 +3,13 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+from functools import partial
+
 import pytest
 import torch
 
 from torchao.float8.float8_utils import compute_error
-from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.ops import mx_fp4_bf16
 from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100
@@ -24,7 +26,11 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     b = torch.rand((N, K), dtype=dtype, device=device)
 
     fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
-    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+    mx_func = (
+        partial(torch._scaled_mm, out_dtype=torch.bfloat16)
+        if format == "fp8"
+        else mx_fp4_bf16
+    )
 
     a_mx = MXTensor.to_mx(a, fmt, 32)
     b_mx = MXTensor.to_mx(b, fmt, 32)
diff --git a/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu b/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu
index 87aa0b0501..9f928d2b89 100644
--- a/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu
+++ b/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu
@@ -221,33 +221,6 @@ void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale
               "Input tensor 'b' must be contiguous in the K dimension (column-major)");
 }
 
-
-at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
-                       at::Tensor b_scale) {
-#if defined(BUILD_MX_KERNELS_CUTLASS)
-  validate(a, b, a_scale, b_scale);
-  auto M = a.size(0);
-  auto K = a.size(1);
-  auto N = b.size(1);
-
-  auto out =
-      at::empty({M, N}, a.options().dtype(at::kBFloat16));
-  using ElementA = cutlass::mx_float8_t;
-  using ElementB = cutlass::mx_float8_t;
-  using ElementD = cutlass::bfloat16_t;
-
-  using MmaTileShape = Shape<_128,_128,_128>;
-  using ClusterShape = Shape<_2,_1,_1>;
-  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
-
-  run_gemm(a, b, a_scale, b_scale, out, M, K, N);
-  return out;
-  #else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
 at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                        at::Tensor b_scale) {
 #if defined(BUILD_MX_KERNELS_CUTLASS)
@@ -278,9 +251,6 @@ at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
 #endif
 }
 
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
-}
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
 }
diff --git a/torchao/ops.py b/torchao/ops.py
index fb84bbf878..0cb9b7e706 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -818,29 +818,6 @@ def _check_scale_dtypes(A_scale, B_scale):
     )
 
 
-def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
-    """Defines a matmul between two fp8 tensors w/ MX scales in E8MO and returns a bf16 tensor.
-
-    This op is prototype subject to change.
-
-    Note: The mx scales are E8MO tensors store in uint8 tensors (for now).
-    The layout of the scales is very particular, see:
-    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
-
-    Args:
-        A: fp8 tensor w/ dtype = torch.float8_e4m3fn
-        B: fp8 tensor w/ dtype = torch.float8_e4m3fn
-        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
-        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
-
-    Returns:
-        MXN bf16 Tensor
-
-    """
-    _check_scale_dtypes(A_scale, B_scale)
-    return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)
-
-
 @register_custom_op("torchao::mx_fp8_bf16")
 def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
     """Meta impl for mx_fp8_bf16"""
diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index ad1462c880..c49e1595a8 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -38,7 +38,6 @@ class MXGemmKernelChoice(Enum):
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
     MXFP8_CUBLAS = "mxfp8_cublas"
-    MXFP8_CUTLASS = "mxfp8_cutlass"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"
 
@@ -126,8 +125,6 @@ def from_recipe_name(
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
             return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
-        elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS:
-            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS)
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=DTYPE_FP4)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index eb9b28ba04..af2d89c112 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -88,18 +88,16 @@ def mx_mm(aten_op, args, kwargs=None):
     b_scale_block = to_blocked(b_scale)
     if a._elem_dtype == torch.float8_e4m3fn:
         assert b._elem_dtype == torch.float8_e4m3fn
-        if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
-            res = torch._scaled_mm(
-                a._data,
-                b._data,
-                a_scale_block.view(torch.float8_e8m0fnu),
-                b_scale_block.view(torch.float8_e8m0fnu),
-                out_dtype=torch.bfloat16,
-            )
-        else:
-            res = torchao.ops.mx_fp8_bf16(
-                a._data, b._data, a_scale_block, b_scale_block
-            )
+        assert a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS, (
+            "CUBLAS is the only supported kernel choice for MX FP8 operations"
+        )
+        res = torch._scaled_mm(
+            a._data,
+            b._data,
+            a_scale_block.view(torch.float8_e8m0fnu),
+            b_scale_block.view(torch.float8_e8m0fnu),
+            out_dtype=torch.bfloat16,
+        )
     else:
         assert a._elem_dtype == DTYPE_FP4
         assert b._elem_dtype == DTYPE_FP4
diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/float8/roofline_utils.py
index 92becb9b94..7bfb9887df 100644
--- a/torchao/testing/float8/roofline_utils.py
+++ b/torchao/testing/float8/roofline_utils.py
@@ -180,7 +180,6 @@ def get_tensor_memory_traffic_ovhd_s(
     else:
         assert mx_recipe_name in (
             "mxfp8_emulated",
-            "mxfp8_cutlass",
            "mxfp8_cublas",
         ), "unsupported"
         # For now, assume that we can't profitably fuse kernel 1 and kernel 2
@@ -227,7 +226,6 @@ def get_individual_gemm_time_sympy(
     if mx_recipe_name is not None:
         assert mx_recipe_name in (
             "mxfp8_emulated",
-            "mxfp8_cutlass",
             "mxfp8_cublas",
         ), "unsupported"
     assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), "unsupported"
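For reference, the sketch below strings together the pieces this diff keeps for the MXFP8 path: quantize with `MXTensor.to_mx`, swizzle the E8M0 scales with `to_blocked`, and call `torch._scaled_mm` with `out_dtype=torch.bfloat16`, mirroring the new cuBLAS-only branch in `mx_ops.py`. It is illustrative only and not part of the patch; the `_data` / `_scale_e8m0` attribute names and the scale reshape to `(rows, K // 32)` are assumptions based on the surrounding torchao code, and running it requires an SM100-class GPU plus a PyTorch build (2.8+, per the test skips above) whose `torch._scaled_mm` accepts `float8_e8m0fnu` scales.

```python
# Illustrative sketch only -- not part of the diff. Attribute names
# (_data, _scale_e8m0) and the scale reshape are assumptions; the
# torch._scaled_mm call mirrors the new cuBLAS branch in mx_ops.py.
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 128, 256, 512
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Quantize both operands to MXFP8 (e4m3) with a block size of 32,
# as run_matrix_test does above.
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

# Swizzle the per-block E8M0 scales into the blocked layout cuBLAS expects.
a_scale_block = to_blocked(a_mx._scale_e8m0.reshape(M, K // 32))
b_scale_block = to_blocked(b_mx._scale_e8m0.reshape(N, K // 32))

# The cuBLAS-backed gemm that replaces the removed mx_fp8_bf16 CUTLASS op.
out = torch._scaled_mm(
    a_mx._data,
    b_mx._data.t(),
    a_scale_block.view(torch.float8_e8m0fnu),
    b_scale_block.view(torch.float8_e8m0fnu),
    out_dtype=torch.bfloat16,
)
assert out.shape == (M, N) and out.dtype == torch.bfloat16
```

With the CUTLASS binding gone, `MXGemmKernelChoice.CUBLAS` is the only non-emulated kernel choice left for MXFP8, which is what the new assertion in `mx_mm` enforces; MXFP4 continues to go through the `mx_fp4_bf16` CUTLASS kernel.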