
ROCm mx-fp8 Gemm #2066


Open · wants to merge 4 commits into `main`
7 changes: 6 additions & 1 deletion torchao/prototype/mx_formats/README.md
@@ -1,7 +1,7 @@
# MX training and inference with native PyTorch

This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
-in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.
+in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 and AMD MI355x hardware.

## Overall status

@@ -29,6 +29,9 @@ from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice
gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
# gemm_kernel_choice = MXGemmKernelChoice.CUTLASS

+# on AMD MI355x GPUs with ROCm 6.5+ and gfx950, you can use HIPBLASLT mxfp8 kernels
+# gemm_kernel_choice = MXGemmKernelChoice.HIPBLASLT

# on older NVIDIA gpus, you can run training with emulated MX gemm
# gemm_kernel_choice = MXGemmKernelChoice.EMULATED

@@ -97,6 +100,8 @@ on supported hardware, you can run the following command:
// example output: https://gist.github.com/vkuzo/a1ddb782e6e1c2aef0c726b3df99efbc
```

+On AMD MI355x GPUs with ROCm 6.5+ and gfx950, we use HIPBLASLT for mxfp8 gemm. We are actively working on optimizing the end-to-end performance for AMD hardware.

## to_mx cast across dim0 and dim1

On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile),
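For readers who want to try the new kernel choice, here is a minimal usage sketch based only on names added or already shown in this PR (`MXGemmKernelChoice.HIPBLASLT`, the `MXFP8_HIPBLASLT` recipe, and the `MXLinearConfig` import from the README above). It assumes `from_recipe_name` is exposed as a classmethod on `MXLinearConfig`, as the config.py diff below suggests; the model-conversion step documented elsewhere in the README is unchanged and omitted here.

```python
import torch

from torchao.prototype.mx_formats import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.config import MXLinearRecipeName

# The HIPBLASLT kernel choice is only valid on a ROCm build (gfx950 / MI355x).
assert torch.version.hip is not None, "HIPBLASLT mxfp8 kernels require ROCm"

# Option 1: construct the config explicitly.
config = MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT)

# Option 2: use the pre-made recipe added in config.py below.
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_HIPBLASLT)
```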
21 changes: 18 additions & 3 deletions torchao/prototype/mx_formats/config.py
@@ -33,12 +33,16 @@ class MXGemmKernelChoice(Enum):
# note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
CUBLAS = "cublas"

+    # available only on ROCm with HIPBLASLT support
+    HIPBLASLT = "hipblaslt"


# Pre-made recipes for common configurations
class MXLinearRecipeName(Enum):
MXFP8_EMULATED = "mxfp8_emulated"
MXFP8_CUBLAS = "mxfp8_cublas"
MXFP8_CUTLASS = "mxfp8_cutlass"
+    MXFP8_HIPBLASLT = "mxfp8_hipblaslt"
MXFP4_EMULATED = "mxfp4_emulated"
MXFP4_CUTLASS = "mxfp4_cutlass"

@@ -63,9 +67,18 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
)
valid_dtypes = [torch.float8_e4m3fn]
-        assert elem_dtype in valid_dtypes, (
-            f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
-        )
+        assert (
+            elem_dtype in valid_dtypes
+        ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
+    elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
+        assert (
+            block_size == 32
+        ), f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}"
+        valid_dtypes = [torch.float8_e4m3fn]
+        assert (
+            elem_dtype in valid_dtypes
+        ), f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}"
+        assert torch.version.hip is not None, "HIPBLASLT requires ROCm"


@dataclass
@@ -128,6 +141,8 @@ def from_recipe_name(
return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS:
return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS)
+        elif recipe_name is MXLinearRecipeName.MXFP8_HIPBLASLT:
+            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT)
elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
return MXLinearConfig(elem_dtype=DTYPE_FP4)
elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
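Condensed into one place, the new `HIPBLASLT` branch only accepts a config when three conditions hold: the PyTorch build is ROCm, `block_size` is 32, and `elem_dtype` is `torch.float8_e4m3fn`. The snippet below is an illustrative restatement of those assertions, not the function itself:

```python
import torch


def mxfp8_hipblaslt_supported(block_size: int, elem_dtype: torch.dtype) -> bool:
    """Illustrative restatement of the HIPBLASLT checks in _validate_gemm_kernel_choice."""
    return (
        torch.version.hip is not None          # requires a ROCm build of PyTorch
        and block_size == 32                   # MX block size expected by the kernel
        and elem_dtype == torch.float8_e4m3fn  # mxfp8 only; mxfp4 stays on CUTLASS
    )
```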
20 changes: 16 additions & 4 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -75,8 +75,14 @@ def mx_mm(aten_op, args, kwargs=None):
b = args[1]
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported"
-    if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
-        # real MX gemm backed by torchao's CUTLASS kernels
+    kernel_choice = a._gemm_kernel_choice
+    valid_kernels = (
+        MXGemmKernelChoice.CUBLAS,
+        MXGemmKernelChoice.CUTLASS,
+        MXGemmKernelChoice.HIPBLASLT,
+    )
+    if kernel_choice in valid_kernels:
+        # real MX gemm backed by torchao's CUTLASS/CUBLAS/HIPBLASLT kernels
M, K, N = a.shape[0], a.shape[1], b.shape[1]
assert a._data.is_contiguous()
assert b._data.t().is_contiguous()
@@ -88,7 +94,12 @@ b_scale_block = to_blocked(b_scale)
b_scale_block = to_blocked(b_scale)
if a._elem_dtype == torch.float8_e4m3fn:
assert b._elem_dtype == torch.float8_e4m3fn
-            if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
+            scaled_mm_kernels = (
+                MXGemmKernelChoice.CUBLAS,
+                MXGemmKernelChoice.HIPBLASLT,
+            )
+            if kernel_choice in scaled_mm_kernels:
+                # Use native scaled_mm for both CUBLAS and HIPBLASLT
res = torch._scaled_mm(
a._data,
b._data,
@@ -103,7 +114,8 @@ else:
else:
assert a._elem_dtype == DTYPE_FP4
assert b._elem_dtype == DTYPE_FP4
-            assert a._gemm_kernel_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
+            msg = "FP4 is only supported with CUTLASS kernel at this moment"
+            assert kernel_choice is MXGemmKernelChoice.CUTLASS, msg
res = torchao.ops.mx_fp4_bf16(
a._data, b._data, a_scale_block, b_scale_block
)
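In outline, the routing that `mx_mm` performs after this change is: mxfp8 with `CUBLAS` or `HIPBLASLT` goes through `torch._scaled_mm`, mxfp8 with `CUTLASS` goes through torchao's CUTLASS kernel, and mxfp4 remains CUTLASS-only. The sketch below restates that dispatch in a simplified, hypothetical form; the callables are placeholders and the actual argument handling of `torch._scaled_mm` is elided.

```python
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice


def route_mx_gemm(kernel_choice, elem_dtype, scaled_mm_fn, cutlass_fp8_fn, cutlass_fp4_fn):
    """Simplified sketch of the fp8/fp4 dispatch in mx_mm; callables are placeholders."""
    if elem_dtype == torch.float8_e4m3fn:
        if kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.HIPBLASLT):
            # cuBLAS (NVIDIA) and hipBLASLt (ROCm) are both reached through
            # torch._scaled_mm, so a single branch covers both backends.
            return scaled_mm_fn()
        # otherwise use torchao's CUTLASS mxfp8 kernel (emulated path not shown)
        return cutlass_fp8_fn()
    # fp4 is still CUTLASS-only
    assert kernel_choice is MXGemmKernelChoice.CUTLASS, (
        "FP4 is only supported with CUTLASS kernel at this moment"
    )
    return cutlass_fp4_fn()
```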