44 changes: 43 additions & 1 deletion benchmarks/mx_formats/cast_bench.py
@@ -59,8 +59,15 @@ def to_mx_dim0_reference(
block_size,
scaling_mode=ScaleCalculationMode.FLOOR,
target_dtype=torch.float8_e4m3fn,
use_fp32_to_fp4_triton_kernel=False,
):
scale_d0, data_d0 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode)
scale_d0, data_d0 = to_mx(
x_hp,
target_dtype,
block_size,
scaling_mode=scaling_mode,
use_fp32_to_fp4_triton_kernel=use_fp32_to_fp4_triton_kernel,
)
return data_d0, scale_d0


@@ -96,6 +103,7 @@ def run(
"dim0_dim1",
"dim0_mxfp8_floor",
"dim0_mxfp4_floor",
"dim0_mxfp4_triton_floor",
"dim0_mxfp8_rceil",
"dim1_mxfp8_floor",
"dim1_mxfp8_rceil",
@@ -204,6 +212,40 @@ def run(
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim0_mxfp4_triton_floor":
to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
y_d0, s_d0 = to_mx_dim0_reference_c(
x,
BLOCK_SIZE,
target_dtype=torch.float4_e2m1fn_x2,
use_fp32_to_fp4_triton_kernel=True,
)

for _ in range(2):
__ = to_mx_dim0_reference_c(
x,
BLOCK_SIZE,
target_dtype=torch.float4_e2m1fn_x2,
use_fp32_to_fp4_triton_kernel=True,
)
time_us = benchmark_cuda_function_in_microseconds(
lambda x, b: to_mx_dim0_reference_c(
x,
BLOCK_SIZE,
target_dtype=torch.float4_e2m1fn_x2,
use_fp32_to_fp4_triton_kernel=True,
),
x,
BLOCK_SIZE,
)

# TODO(future PR): make to_mx return float4 directly
assert y_d0.dtype == torch.uint8
assert s_d0.dtype == torch.float8_e8m0fnu
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim0_mxfp8_rceil":
to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
26 changes: 26 additions & 0 deletions test/prototype/mx_formats/test_kernels.py
@@ -561,3 +561,29 @@ def test_cuda_mx_dim1_invalid_block_size():
scale_dim_x=1,
scale_dim_y=invalid_block_size,
)


def _fp32_to_fp4_reference(
data_hp: torch.Tensor,
) -> torch.Tensor:
data_lp = f32_to_f4_unpacked(data_hp.float())
data_lp = pack_uint4(data_lp)
return data_lp


@pytest.mark.skipif(
not is_sm_at_least_100(),
reason="requires CUDA capability 10.0 or greater",
)
def test_fp32_cast_to_fp4x2():
from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2

M, K = 16, 16
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
# scale x so it roughly spans the representable range of fp4 (e2m1 max magnitude is 6.0)
x = x * 6.0

data_ref = _fp32_to_fp4_reference(x)
data = triton_fp32_cast_to_fp4x2(x)
torch.testing.assert_close(data_ref, data, atol=0, rtol=0)
assert data.shape == (M, K // 2)
16 changes: 16 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -96,6 +96,22 @@ def test_realistic_numerics(elem_dtype, scale_calculation_mode):
_test_mx(data, elem_dtype, block_size, scale_calculation_mode)


def test_fp4_triton_cast_does_not_change_numerics():
# TODO(before land): proper skips
# TODO(before land): test rank 3
data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
data_mx_ref = MXTensor.to_mx(
data, torch.float4_e2m1fn_x2, 32, use_fp32_to_fp4_triton_kernel=False
)
data_mx = MXTensor.to_mx(
data, torch.float4_e2m1fn_x2, 32, use_fp32_to_fp4_triton_kernel=True
)
torch.testing.assert_close(data_mx_ref.qdata, data_mx.qdata, atol=0, rtol=0)
torch.testing.assert_close(
data_mx_ref._scale_e8m0, data_mx._scale_e8m0, atol=0, rtol=0
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_all_zeros(elem_dtype):
54 changes: 54 additions & 0 deletions torchao/prototype/mx_formats/kernels.py
@@ -1454,6 +1454,57 @@ def _(scale_tensor):
padded_cols = n_col_blocks * 4

return scale_tensor.new_empty((padded_rows, padded_cols))

@triton.jit
def fp32_cast_to_fp4x2_triton_kernel(
x_ptr,
q_ptr,
stride_xm,
stride_xn,
M,
N,
):
pid_m = tl.program_id(1)
pid_n = tl.program_id(0)
offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
mask = None
other = None
x = tl.load(
x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
) # [128, 64]
x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16]
# Convert to FP4
x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
mask = (offs_m < M) & (offs_n < N // 2)
tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)

def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
"""
Input: a float32 (or other high-precision, e.g. bfloat16) tensor with shape (M, N)
Output: a uint8 tensor with shape (M, N // 2), with the values being the result
of casting each original value to fp4_e2m1, and then packing fp4x2

TODO(future PR): optimize performance, lowest hanging fruit is we want
to add an e8m0 scale and scale the incoming tensor inside of this kernel
TODO(future PR): better checks for shapes, etc
TODO(future PR): integrate into training/inference
TODO(future PR): integrate with compile, ideally allowing fusion
"""
M, N = x.shape
xq = x.new_empty(M, N // 2, dtype=torch.uint8)
grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
fp32_cast_to_fp4x2_triton_kernel[grid](
x,
xq,
x.stride(0),
x.stride(1),
M,
N,
)
return xq.view(torch.uint8)
else:

def triton_to_mxfp8_dim1(
@@ -1475,6 +1526,9 @@ def triton_quantize_nvfp4(
) -> Tuple[torch.Tensor, torch.Tensor]:
raise AssertionError("needs torch version 2.8+ and triton")

def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
raise AssertionError("needs torch version 2.8+ and triton")


# MXFP8 CUDA kernel is only built on SM100+
if is_sm_at_least_100():
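For reference, a minimal usage sketch of the new kernel — it mirrors the unit test above; an SM100+ GPU, a torch 2.8+/Triton build, and the import path shown in the test are assumed:

import torch
from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2

# values scaled into the fp4 e2m1 range, as in the test above
x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16) * 6.0
packed = triton_fp32_cast_to_fp4x2(x)
# packed is uint8 with shape (16, 8): each byte holds two packed fp4_e2m1 values
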
28 changes: 20 additions & 8 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -55,6 +55,7 @@
pack_uint6,
triton_f6_e2m3_to_scaled_bf16,
triton_f6_e3m2_to_scaled_bf16,
triton_fp32_cast_to_fp4x2,
unpack_uint4,
)
from torchao.quantization.quantize_.common import (
@@ -134,6 +135,7 @@ def to_mx(
block_size: int,
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
pack_fp6: bool = False,
use_fp32_to_fp4_triton_kernel: bool = False,
):
"""
Takes a high precision tensor and converts to MX scale and raw data, in
@@ -309,13 +311,17 @@
# need to reshape at the end to help inductor fuse things
data_lp = data_lp.reshape(orig_shape)
elif elem_dtype == torch.float4_e2m1fn_x2:
# can't reshape at the end without handling it in the packing code,
# punt until later since we'll need to rethink the torch.compile
# approach for fp4x2 in any case
data_lp = data_lp.reshape(orig_shape)
data_lp = f32_to_f4_unpacked(data_lp)
orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2]
data_lp = pack_uint4(data_lp)
if use_fp32_to_fp4_triton_kernel:
data_lp = data_lp.reshape(orig_shape)
data_lp = triton_fp32_cast_to_fp4x2(data_lp)
else:
# can't reshape at the end without handling it in the packing code,
# punt until later since we'll need to rethink the torch.compile
# approach for fp4x2 in any case
data_lp = data_lp.reshape(orig_shape)
data_lp = f32_to_f4_unpacked(data_lp)
orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2]
data_lp = pack_uint4(data_lp)
else:
raise AssertionError("unsupported")

@@ -583,9 +589,15 @@ def to_mx(
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
pack_fp6: bool = False,
act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None,
use_fp32_to_fp4_triton_kernel: bool = False,
):
scale_e8m0_biased, data_lp = to_mx(
data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
data_hp,
elem_dtype,
block_size,
scaling_mode,
pack_fp6,
use_fp32_to_fp4_triton_kernel,
)
if isinstance(scale_e8m0_biased, DTensor):
assert isinstance(data_lp, DTensor), "unsupported"
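A minimal end-to-end sketch of the new flag, mirroring test_fp4_triton_cast_does_not_change_numerics above (the MXTensor import path is assumed; the flag name and tensor attributes come from this diff):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
data_mx = MXTensor.to_mx(
    data, torch.float4_e2m1fn_x2, 32, use_fp32_to_fp4_triton_kernel=True
)
# data_mx.qdata holds the packed fp4x2 payload as uint8,
# data_mx._scale_e8m0 holds the float8_e8m0fnu block scales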