
Commit d37dcb7

mx: delete use_fp4_custom_triton_dequant_kernel option (#2831)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 1aabda0 commit d37dcb7

8 files changed: +11 −370 lines changed


test/prototype/mx_formats/test_kernels.py

Lines changed: 1 addition & 23 deletions
@@ -41,7 +41,7 @@
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
-from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx
+from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
@@ -326,28 +326,6 @@ def test_fp4_pack_unpack():
     assert torch.all(orig_vals_dq == orig_vals)
 
 
-# TODO(future PR): fix or delete this test
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(is_sm_at_least_89(), reason="broken on CUDA capability 8.9+")
-def test_fp4_triton_scaled_cast():
-    size = (256,)
-    orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
-    mxtensor_ref = MXTensor.to_mx(
-        orig_vals, block_size=32, elem_dtype=torch.float4_e2m1fn_x2
-    )
-    mxtensor_triton = MXTensor.to_mx(
-        orig_vals,
-        block_size=32,
-        elem_dtype=torch.float4_e2m1fn_x2,
-        use_fp4_custom_triton_dequant_kernel=True,
-    )
-
-    f32_ref = mxtensor_ref.to_dtype(torch.float)
-    f32_triton = mxtensor_triton.to_dtype(torch.float)
-    assert torch.all(torch.eq(f32_ref, f32_triton))
-
-
 @pytest.mark.parametrize("dtype_name", (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2))
 def test_fp6_values(dtype_name):
     """

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 11 deletions
@@ -380,14 +380,12 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     else:
         raise AssertionError("unsupported")
     block_size = 4
-    use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
         data_bits,
         scale_e8m0,
         elem_dtype,
         block_size,
         torch.float,
-        use_fp4_custom_triton_dequant_kernel,
         MXGemmKernelChoice.EMULATED,
         pack_fp6,
         None,
@@ -427,22 +425,17 @@ def test_block_sizes(elem_dtype, B):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-@pytest.mark.parametrize("fp4_triton", [False, True])
-def test_transpose(elem_dtype, fp4_triton):
+def test_transpose(elem_dtype):
     """
     Verify that transposing an MX tensor works
     """
-    if elem_dtype != torch.float4_e2m1fn_x2 and fp4_triton:
-        pytest.skip("unsupported configuration")
-
     M, K = 128, 256
     block_size = 32
     tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
     tensor_mx = MXTensor.to_mx(
         tensor_hp,
         elem_dtype,
         block_size,
-        use_fp4_custom_triton_dequant_kernel=fp4_triton,
     )
     tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
@@ -510,15 +503,13 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
 
     to_dtype_c = torch.compile(to_dtype, fullgraph=True)
 
-    use_fp4_custom_triton_dequant_kernel = False
     pack_fp6 = False
     x_mx_dq = to_dtype(
         x_mx.qdata,
         x_mx._scale_e8m0,
         x_mx._elem_dtype,
         x_mx._block_size,
         hp_dtype,  # noqa: E501
-        use_fp4_custom_triton_dequant_kernel,
         pack_fp6,
     )
     x_mx_c_dq = to_dtype_c(
@@ -527,7 +518,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x_mx_c._elem_dtype,
         x_mx_c._block_size,
         hp_dtype,
-        use_fp4_custom_triton_dequant_kernel,
         pack_fp6,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
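
With the flag gone, the fp4 dequant kernel is chosen internally and callers of MXTensor just quantize and dequantize without any kernel-selection argument. A minimal usage sketch of the post-change API, assuming a CUDA device and the import paths visible in the diffs above:

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

# Quantize a high-precision tensor to MX fp4; note there is no
# use_fp4_custom_triton_dequant_kernel keyword anymore.
x_hp = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x_hp, torch.float4_e2m1fn_x2, block_size=32)

# Dequantize back to the high-precision dtype.
x_dq = x_mx.to_dtype(torch.bfloat16)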

torchao/prototype/mx_formats/benchmarks/bench_qdq.py

Lines changed: 0 additions & 146 deletions
This file was deleted.

torchao/prototype/mx_formats/config.py

Lines changed: 0 additions & 5 deletions
@@ -146,9 +146,6 @@ class MXLinearConfig(AOBaseConfig):
         MXFP8Dim1CastKernelChoice.TORCH
     )
 
-    # If True, uses a custom triton kernel for fp4 dequantize
-    use_fp4_custom_triton_dequant_kernel: bool = False
-
     scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
 
     def __post_init__(self):
@@ -217,8 +214,6 @@ def short_str(self) -> str:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
         s += f", kernel={self.gemm_kernel_choice.value}"
         s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
-        if self.use_fp4_custom_triton_dequant_kernel:
-            s += ", use_fp4_custom_triton_dequant_kernel=True"
         if self.scale_calculation_mode != ScaleCalculationMode.FLOOR:
             s += f", scale_calculation_mode={self.scale_calculation_mode}"
         return s
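
Since use_fp4_custom_triton_dequant_kernel is also removed from MXLinearConfig, existing configs simply drop that field. A minimal sketch of the trimmed config, assuming the class still default-constructs from the remaining fields shown above:

from torchao.prototype.mx_formats.config import MXLinearConfig

# Default-construct the config; the fp4 triton dequant flag is no longer a field,
# so passing use_fp4_custom_triton_dequant_kernel=... would now be rejected.
cfg = MXLinearConfig()

# short_str() no longer appends ", use_fp4_custom_triton_dequant_kernel=True".
print(cfg.short_str())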
