10 changes: 7 additions & 3 deletions test/prototype/moe_training/test_kernels.py
@@ -26,15 +26,19 @@
torch_to_float8_per_group_colwise,
torch_to_float8_per_group_rowwise,
)
from torchao.testing.utils import skip_if_rocm
from torchao.testing.utils import (
skip_if_rocm,
)
from torchao.utils import auto_detect_device

_DEVICE = auto_detect_device()

@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
# tests the case where rowwise scales are computed for multiple distinct subtensors,
# with the end boundary of each group determined by its end column index (offset).
device = "cuda"
device = _DEVICE
m, k, n_groups = 256, 256, 4
x = torch.randn(m, k * n_groups, device=device)
colwise_offs = torch.arange(k, k * n_groups + 1, k, device=device)
@@ -62,7 +66,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: bool):
# tests the case where colwise scales are computed for multiple distinct subtensors,
# with the end boundary of each group determined by its end row index (offset).
device = "cuda"
device = _DEVICE
m, k, n_groups = 256, 256, 4
x = torch.randn(m * n_groups, k, device=device).t().contiguous().t()
rowwise_offs = torch.arange(m, m * n_groups + 1, m, device=device)
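The edit above swaps the hard-coded `"cuda"` device for a module-level `_DEVICE = auto_detect_device()`. The helper itself is not part of this diff; a minimal sketch of what such a function could look like, assuming it simply probes the backends PyTorch exposes (the real `torchao.utils.auto_detect_device` may differ), is:

```python
import torch


def auto_detect_device() -> str:
    """Pick an accelerator if one is present, otherwise fall back to CPU.

    Illustrative sketch only; the actual torchao.utils.auto_detect_device
    implementation is not shown in this diff and may behave differently.
    """
    if torch.cuda.is_available():  # covers both CUDA and ROCm builds
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel XPU builds
        return "xpu"
    return "cpu"
```

With a helper like this in place, `torch.randn(..., device=_DEVICE)` allocates tensors on whatever backend the test runner actually has, instead of assuming CUDA.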
30 changes: 17 additions & 13 deletions test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -14,8 +14,6 @@
# triton won't be available on CPU builds; these tests also require torch >= 2.7
if not (
TORCH_VERSION_AT_LEAST_2_7
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@@ -39,13 +37,16 @@
generate_jagged_offs,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.testing.utils import skip_if_rocm
from torchao.testing.utils import skip_if_rocm, skip_if_xpu
from torchao.utils import auto_detect_device

_DEVICE = auto_detect_device()

@skip_if_rocm("ROCm not supported")
@skip_if_xpu("XPU not supported")
def test_valid_scaled_grouped_mm_2d_3d():
out_dtype = torch.bfloat16
device = "cuda"
device = _DEVICE
m, n, k, n_groups = 16, 32, 16, 4
a = torch.randn(
m * n_groups,
@@ -61,7 +62,7 @@ def test_valid_scaled_grouped_mm_2d_3d():
device=device,
dtype=torch.bfloat16,
)
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
offs = torch.arange(m, n_groups * m + 1, m, device=_DEVICE, dtype=torch.int32)

# b must be transposed and in column major format.
b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)
@@ -109,7 +110,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
if n % 16 == 0 and k % 16 == 0:
return
out_dtype = torch.bfloat16
device = "cuda"
device = _DEVICE
n_groups = 4
a = torch.randn(
m * n_groups,
@@ -131,7 +132,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
b_t = b.transpose(-2, -1)
b_t = b_t.transpose(-2, -1).contiguous().transpose(-2, -1)

offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
offs = torch.arange(m, n_groups * m + 1, m, device=_DEVICE, dtype=torch.int32)

# Compute output.
with pytest.raises(AssertionError):
@@ -226,11 +227,12 @@ def compute_reference_forward(


@skip_if_rocm("ROCm not supported")
@skip_if_xpu("XPU not supported")
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=_DEVICE)
offs = generate_jagged_offs(num_experts, M)
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()

@@ -257,15 +259,16 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):


@skip_if_rocm("ROCm not supported")
@skip_if_xpu("XPU not supported")
@pytest.mark.parametrize("M", (1024, 4096))
@pytest.mark.parametrize("N", (1024, 4096))
@pytest.mark.parametrize("num_experts", (8, 16))
def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
# Simulate 2d-2d grouped gemm grad_weight = grad_output_t @ x
block_size = 32
grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
grad_out = torch.randn(M, N, dtype=torch.bfloat16, device=_DEVICE)
grad_out_t = grad_out.t().contiguous()
x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
x = torch.randn(M, N, dtype=torch.bfloat16, device=_DEVICE)
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
x_ref, grad_out_t_ref, offs_ref = x.clone(), grad_out_t.clone(), offs.clone()

@@ -305,6 +308,7 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):


@skip_if_rocm("ROCm not supported")
@skip_if_xpu("XPU not supported")
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
@@ -313,9 +317,9 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
)

block_size = 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE, requires_grad=True)
w_t = torch.randn(
num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
num_experts, K, N, dtype=torch.bfloat16, device=_DEVICE, requires_grad=True
)
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
x_ref, w_t_ref, offs_ref = (
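`skip_if_xpu` is imported from `torchao.testing.utils`, but its body is not visible in this diff. A hypothetical sketch of such a decorator, assuming it mirrors the call-time skipping style of the existing `skip_if_rocm` helper, might look like:

```python
import functools

import pytest
import torch


def skip_if_xpu(reason: str):
    """Skip the wrapped test when an XPU backend is available.

    Hypothetical sketch; the real torchao.testing.utils.skip_if_xpu may be
    implemented differently (e.g. as a pytest.mark.skipif wrapper).
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                pytest.skip(reason)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```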
60 changes: 28 additions & 32 deletions test/prototype/mx_formats/test_kernels.py
@@ -45,11 +45,18 @@
from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
get_available_devices,
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_100,
)

from torchao.testing.utils import skip_if_xpu


_DEVICES = get_available_devices()
torch.manual_seed(0)

if not TORCH_VERSION_AT_LEAST_2_8:
@@ -327,19 +334,19 @@ def test_fp4_pack_unpack():
assert torch.all(orig_vals_dq == orig_vals)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0")
@skip_if_xpu("XPU not Support")
def test_fp4_triton_unscaled_cast():
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
f32_triton = triton_f4_to_bf16(packed_vals).to(torch.float)
assert torch.all(torch.eq(f32_ref, f32_triton))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0")
@skip_if_xpu("XPU not Support")
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
Expand All @@ -357,7 +364,7 @@ def test_fp4_triton_scaled_cast():
f32_triton = mxtensor_triton.to_dtype(torch.float)
assert torch.all(torch.eq(f32_ref, f32_triton))


@skip_if_xpu("XPU not Support")
@pytest.mark.parametrize("dtype_name", (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2))
def test_fp6_values(dtype_name):
"""
Expand Down Expand Up @@ -403,18 +410,8 @@ def test_fp6_values(dtype_name):
torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0)


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
],
)
@skip_if_xpu("XPU not Support")
@pytest.mark.parametrize("device", _DEVICES)
@pytest.mark.parametrize(
"f32_val,f6_e3m2_enc",
[
@@ -433,12 +430,11 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@skip_if_xpu("XPU not Support")
@pytest.mark.parametrize("device", _DEVICES)
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_fp6_e2m3_pack_unpack():
orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
"cuda"
)
def test_fp6_e2m3_pack_unpack(device):
orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(device)
orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
Expand All @@ -448,12 +444,11 @@ def test_fp6_e2m3_pack_unpack():
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@skip_if_xpu("XPU not Support")
@pytest.mark.parametrize("device", _DEVICES)
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_fp6_e3m2_pack_unpack():
orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
"cuda"
)
def test_fp6_e3m2_pack_unpack(device):
orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(device)
orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
Expand All @@ -471,14 +466,15 @@ def test_fp6_e3m2_pack_unpack():
@pytest.mark.parametrize("M", (256, 2048))
@pytest.mark.parametrize("K", (256, 2048))
def test_triton_mxfp8_dim1_randn(M, K):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@skip_if_xpu("XPU not Support")
@pytest.mark.parametrize("device", _DEVICES)
@pytest.mark.parametrize(
"shape",
[
Expand All @@ -492,8 +488,8 @@ def test_triton_mxfp8_dim1_randn(M, K):
(128, 1),
],
)
def test_rearrange(shape):
scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
def test_rearrange(device, shape):
scales = torch.randint(256, size=shape, device=device, dtype=torch.uint8)
eager = to_blocked(scales, False)
triton = to_blocked(scales, True)
torch.testing.assert_close(eager, triton, atol=0, rtol=0)
@@ -519,7 +515,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):

# Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
x = (
torch.arange(0, M * K, dtype=input_dtype, device="cuda")
torch.arange(0, M * K, dtype=input_dtype, device=device)
.reshape(M, K)
.contiguous()
)
@@ -557,7 +553,7 @@ def test_cuda_mx_dim0_not_supported():
M, K = 64, 64
block_size = 32
x = (
torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
torch.arange(0, M * K, dtype=torch.bfloat16, device=device)
.reshape(M, K)
.contiguous()
)
@@ -580,7 +576,7 @@ def test_cuda_mx_dim1_invalid_block_size():

M, K = 64, 64
x = (
torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
torch.arange(0, M * K, dtype=torch.bfloat16, device=device)
.reshape(M, K)
.contiguous()
)
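The recurring pattern in this file is to replace per-test CUDA skip marks with `@pytest.mark.parametrize("device", _DEVICES)`, where `_DEVICES = get_available_devices()`. A small self-contained illustration of that pattern follows; the device-list logic shown is an assumption, since the real `get_available_devices` lives in `torchao.utils` and is not part of this diff:

```python
import pytest
import torch


def get_available_devices():
    """Assumed behaviour: list every backend the current build can use."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")
    return devices


_DEVICES = get_available_devices()


@pytest.mark.parametrize("device", _DEVICES)
def test_roundtrip_on_every_backend(device):
    # One test instance per available backend, instead of a CUDA-only skipif.
    x = torch.randn(8, 8, device=device)
    assert x.device.type == device
```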