Skip to content

Commit 0154432

Browse files
committed
[wip] mx: expose a fast path for casting to fp4x2
Summary: not ready for review yet Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 2d88961 ghstack-comment-id: 3210931181 Pull-Request: #2832
1 parent ad9155a commit 0154432

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

test/prototype/mx_formats/test_kernels.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,45 @@ def test_cuda_mx_dim1_invalid_block_size():
561561
scale_dim_x=1,
562562
scale_dim_y=invalid_block_size,
563563
)
564+
565+
566+
def _fp32_to_fp4_reference(
567+
data_hp: torch.Tensor,
568+
) -> torch.Tensor:
569+
# works
570+
data_hp = data_hp.float()
571+
data_lp = f32_to_f4_unpacked(data_hp)
572+
573+
# does not work
574+
# data_lp = f32_to_f4_unpacked(data_hp.float())
575+
576+
data_lp = pack_uint4(data_lp)
577+
return data_lp
578+
579+
580+
# TODO add skips
581+
def test_fp32_cast_to_fp4x2():
582+
from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2
583+
584+
M, K = 16, 16
585+
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
586+
# make x's range be the representable range of fp4
587+
x = x * 6.0
588+
589+
# this leads to values in `x` being overridden inplace
590+
# TODO fix it
591+
print(0, x)
592+
data = triton_fp32_cast_to_fp4x2(x)
593+
print(1, x)
594+
return
595+
596+
data_ref = _fp32_to_fp4_reference(x)
597+
# print(2, x[0])
598+
data = triton_fp32_cast_to_fp4x2(x)
599+
# print(3, x[0])
600+
# print(0, x)
601+
# print(1, data_ref, data_ref.shape)
602+
# print(2, data, data.shape)
603+
torch.testing.assert_close(data_ref, data)
604+
assert data.shape == (M, K // 2)
605+
print("done")

torchao/prototype/mx_formats/kernels.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,49 @@ def _(scale_tensor):
14541454
padded_cols = n_col_blocks * 4
14551455

14561456
return scale_tensor.new_empty((padded_rows, padded_cols))
1457+
1458+
@triton.jit
1459+
def fp32_cast_to_fp4x2_triton_kernel(
1460+
x_ptr,
1461+
q_ptr,
1462+
stride_xm,
1463+
stride_xn,
1464+
M,
1465+
N,
1466+
):
1467+
pid_m = tl.program_id(1)
1468+
pid_n = tl.program_id(0)
1469+
1470+
offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
1471+
offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
1472+
mask = None
1473+
other = None
1474+
x = tl.load(
1475+
x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
1476+
) # [128, 64]
1477+
x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16]
1478+
1479+
# Convert to FP4
1480+
x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
1481+
offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
1482+
offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
1483+
tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=None)
1484+
1485+
def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1486+
M, N = x.shape
1487+
assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"
1488+
xq = x.new_empty(M, N // 2, dtype=torch.uint8)
1489+
grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
1490+
fp32_cast_to_fp4x2_triton_kernel[grid](
1491+
x,
1492+
xq,
1493+
x.stride(0),
1494+
x.stride(1),
1495+
M,
1496+
N,
1497+
)
1498+
1499+
return xq.view(torch.uint8)
14571500
else:
14581501

14591502
def triton_to_mxfp8_dim1(

0 commit comments

Comments
 (0)