From 59f2bf78bf696d16bd3338176dd49e6b8628aa95 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 21 Aug 2025 06:22:50 -0700 Subject: [PATCH 1/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_kernels.py | 12 --- torchao/prototype/mx_formats/kernels.py | 102 ---------------------- 2 files changed, 114 deletions(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index d04a67771d..e553946413 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -35,7 +35,6 @@ get_bits, pack_uint4, pack_uint6, - triton_f4_to_bf16, triton_f6_e2m3_to_bf16, triton_f6_e3m2_to_bf16, triton_to_mxfp8_dim1, @@ -327,17 +326,6 @@ def test_fp4_pack_unpack(): assert torch.all(orig_vals_dq == orig_vals) -# TODO(future PR): fix or delete this test -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -@pytest.mark.skipif(is_sm_at_least_89(), reason="broken on CUDA capability 8.9+") -def test_fp4_triton_unscaled_cast(): - packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda") - f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals)) - f32_triton = triton_f4_to_bf16(packed_vals).to(torch.float) - assert torch.all(torch.eq(f32_ref, f32_triton)) - - # TODO(future PR): fix or delete this test @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index cabb61276a..732af4df2a 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -196,55 +196,6 @@ def _fp4_packed_to_bf16( output = output.to(tl.bfloat16) return output - @triton.jit - def triton_f4_to_bf16_kernel( - x_ptr, - output_ptr, - n_elements_in, - sign_mask_f4: tl.constexpr, - mantissa_mask_f4: tl.constexpr, - mbits_f4_e2m1: tl.constexpr, - ebits_f4_e2m1: tl.constexpr, - f4_e2m1_exp_bias: tl.constexpr, - mbits_f32: tl.constexpr, - ebits_f32: tl.constexpr, - f32_exp_bias: tl.constexpr, - zero_bits_f32: tl.constexpr, - zero_point_five_bits_f32: tl.constexpr, - BLOCK_SIZE_IN: tl.constexpr, - ): - pid = tl.program_id(axis=0) - n_elements_out = n_elements_in * 2 - BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2 - - block_start_in = pid * BLOCK_SIZE_IN - offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN) - - mask_in = offsets_in < n_elements_in - - # packed uint8 - x_packed = tl.load(x_ptr + offsets_in, mask=mask_in) - output = _fp4_packed_to_bf16( - x_packed, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - ) - - # set up output offsets - block_start_out = pid * BLOCK_SIZE_OUT - offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT) - mask_out = offsets_out < n_elements_out - - tl.store(output_ptr + offsets_out, output, mask=mask_out) - @triton.autotune( configs=[ triton.Config({"BLOCK_SIZE_IN": 128}), @@ -624,24 +575,6 @@ def triton_pack_uint6_kernel( else: - def triton_f4_to_bf16_kernel( - x_ptr, - output_ptr, - n_elements_in, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - BLOCK_SIZE_IN, - ): - raise AssertionError("unsupported without triton") - def triton_f4_to_scaled_bf16_kernel( x_ptr, s_ptr, @@ -705,41 +638,6 @@ def triton_pack_uint6_kernel( raise AssertionError("unsupported without triton") -def triton_f4_to_bf16(x: torch.Tensor): - """ - Input: a tensor of packed fp4 values - Output: a tensor of bfloat16 values - - Note: this function is only used in testing, so we can test - the numerical correctness of the cast without the scaling. - """ - new_shape = (*x.shape[:-1], x.shape[-1] * 2) - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda - n_elements_in = x.numel() - grid = lambda meta: ( # noqa: E731 - triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]), - ) # noqa: E731,E501 - triton_f4_to_bf16_kernel[grid]( - x, - output, - n_elements_in, - sign_mask_f4=SIGN_MASK_F4, - mantissa_mask_f4=MANTISSA_MASK_F4, - mbits_f4_e2m1=MBITS_F4_E2M1, - ebits_f4_e2m1=EBITS_F4_E2M1, - f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS, - mbits_f32=MBITS_F32, - ebits_f32=EBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - zero_bits_f32=ZERO_BITS_F32, - zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32, - BLOCK_SIZE_IN=512, - ) - return output - - def triton_f4_to_scaled_bf16( x: torch.Tensor, s_e8m0: torch.Tensor, From 256eed6c0f152381324d9bad695566e3f69108b7 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 21 Aug 2025 06:22:53 -0700 Subject: [PATCH 2/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_kernels.py | 24 +-- test/prototype/mx_formats/test_mx_tensor.py | 12 +- .../mx_formats/benchmarks/bench_qdq.py | 146 ----------------- torchao/prototype/mx_formats/config.py | 5 - torchao/prototype/mx_formats/kernels.py | 147 ------------------ torchao/prototype/mx_formats/mx_linear.py | 2 - torchao/prototype/mx_formats/mx_ops.py | 5 - torchao/prototype/mx_formats/mx_tensor.py | 40 ++--- 8 files changed, 11 insertions(+), 370 deletions(-) delete mode 100644 torchao/prototype/mx_formats/benchmarks/bench_qdq.py diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index e553946413..0957bf0fb9 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -41,7 +41,7 @@ triton_to_mxfp8_dim1_reference, unpack_uint4, ) -from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx +from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, @@ -326,28 +326,6 @@ def test_fp4_pack_unpack(): assert torch.all(orig_vals_dq == orig_vals) -# TODO(future PR): fix or delete this test -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -@pytest.mark.skipif(is_sm_at_least_89(), reason="broken on CUDA capability 8.9+") -def test_fp4_triton_scaled_cast(): - size = (256,) - orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100 - mxtensor_ref = MXTensor.to_mx( - orig_vals, block_size=32, elem_dtype=torch.float4_e2m1fn_x2 - ) - mxtensor_triton = MXTensor.to_mx( - orig_vals, - block_size=32, - elem_dtype=torch.float4_e2m1fn_x2, - use_fp4_custom_triton_dequant_kernel=True, - ) - - f32_ref = mxtensor_ref.to_dtype(torch.float) - f32_triton = mxtensor_triton.to_dtype(torch.float) - assert torch.all(torch.eq(f32_ref, f32_triton)) - - @pytest.mark.parametrize("dtype_name", (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2)) def test_fp6_values(dtype_name): """ diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index f4af52bafa..ea1b7c6459 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -380,14 +380,12 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): else: raise AssertionError("unsupported") block_size = 4 - use_fp4_custom_triton_dequant_kernel = False tensor_mx = MXTensor( data_bits, scale_e8m0, elem_dtype, block_size, torch.float, - use_fp4_custom_triton_dequant_kernel, MXGemmKernelChoice.EMULATED, pack_fp6, None, @@ -427,14 +425,10 @@ def test_block_sizes(elem_dtype, B): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) -@pytest.mark.parametrize("fp4_triton", [False, True]) -def test_transpose(elem_dtype, fp4_triton): +def test_transpose(elem_dtype): """ Verify that transposing an MX tensor works """ - if elem_dtype != torch.float4_e2m1fn_x2 and fp4_triton: - pytest.skip("unsupported configuration") - M, K = 128, 256 block_size = 32 tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) @@ -442,7 +436,6 @@ def test_transpose(elem_dtype, fp4_triton): tensor_hp, elem_dtype, block_size, - use_fp4_custom_triton_dequant_kernel=fp4_triton, ) tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t() @@ -510,7 +503,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): to_dtype_c = torch.compile(to_dtype, fullgraph=True) - use_fp4_custom_triton_dequant_kernel = False pack_fp6 = False x_mx_dq = to_dtype( x_mx.qdata, @@ -518,7 +510,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx._elem_dtype, x_mx._block_size, hp_dtype, # noqa: E501 - use_fp4_custom_triton_dequant_kernel, pack_fp6, ) x_mx_c_dq = to_dtype_c( @@ -527,7 +518,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx_c._elem_dtype, x_mx_c._block_size, hp_dtype, - use_fp4_custom_triton_dequant_kernel, pack_fp6, ) torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) diff --git a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py deleted file mode 100644 index ca0b926ce5..0000000000 --- a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Benchmarking mx quantize/dequantize -""" - -from typing import Optional - -import fire -import tabulate -import torch -from torch.profiler import ProfilerActivity, profile - -from torchao.prototype.mx_formats import config -from torchao.prototype.mx_formats.constants import ( # noqa: E501 - SUPPORTED_ELEM_DTYPES, -) -from torchao.prototype.mx_formats.mx_tensor import MXTensor -from torchao.utils import benchmark_torch_function_in_microseconds - - -def run(profile_folder: Optional[str] = None): - headers = [ - "elem_dtype", - "use_fp4_custom_triton_dequant_kernel", - "q_time_us", - "q_mem_bw_tb_s", - "dq_time_us", - "dq_mem_bw_tb_s", - ] - results = [] - - data_hp = torch.randn(1, 4096, 11008, dtype=torch.bfloat16, device="cuda") - - for elem_dtype in SUPPORTED_ELEM_DTYPES: - for use_fp4_custom_triton_dequant_kernel in (False, True): - config.use_fp4_custom_triton_dequant_kernel = ( - use_fp4_custom_triton_dequant_kernel - ) - - if ( - elem_dtype != torch.float4_e2m1fn_x2 - and use_fp4_custom_triton_dequant_kernel # noqa: E501 - ): - # custom_triton_kernels only works for fp4 - continue - - print( - "elem_dtype", - elem_dtype, - "use_fp4_custom_triton_dequant_kernel", - use_fp4_custom_triton_dequant_kernel, - ) - - data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32) - - if not use_fp4_custom_triton_dequant_kernel: - quant = torch.compile(MXTensor.to_mx, fullgraph=True) - dequant = torch.compile(data_lp.to_dtype, fullgraph=True) - else: - # As of 2024-04, torch.compile didn't work with the - # handwritten triton kernel, - # crashed on tl.interleave: - # https://github.com/pytorch/pytorch/issues/123967 - # As of 2024-05-24, now there is message asking to convert to - # an opaque custom op: - # https://gist.github.com/vkuzo/0b0b90dca03bdb8e0446e4135644238a # noqa: E501 - # TODO(future): make this better - quant = MXTensor.to_mx - dequant = data_lp.to_dtype - - # warm up - quant(data_hp, elem_dtype, block_size=32) - res = dequant(torch.bfloat16) - - if profile_folder is not None: - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True, - ) as prof: - for _ in range(5): - quant(data_hp, elem_dtype, block_size=32) - dequant(torch.bfloat16) - prof.export_chrome_trace( - profile_folder - + f"/mx_qdq_{elem_dtype}_{use_fp4_custom_triton_dequant_kernel}.json" # noqa: E501 - ) - - q_execution_time_us = benchmark_torch_function_in_microseconds( - quant, data_hp, elem_dtype, block_size=32 - ) - dq_execution_time_us = benchmark_torch_function_in_microseconds( - dequant, torch.bfloat16 - ) - print(f"q time: {q_execution_time_us} us") - print(f"dq time: {dq_execution_time_us} us") - - # memory reads per element: - byte_per_stored_element = 1.0 # fp8 or 2xfp4 - byte_per_stored_exp_element = 1.0 # e8m0 - byte_per_dequantized_element = 2.0 # bfloat16 - mem_reads_writes_bytes = ( - # read raw data - (data_lp._data.numel() * byte_per_stored_element) - + - # read exponent - (data_lp._scale_e8m0.numel() * byte_per_stored_exp_element) - + - # write dequant - (res.numel() * byte_per_dequantized_element) - ) - # note: the above also works for quant, with reads/writes in - # reverse - - q_mem_bw_tb_s = (mem_reads_writes_bytes / 1e12) / ( - q_execution_time_us / 1e6 - ) - dq_mem_bw_tb_s = (mem_reads_writes_bytes / 1e12) / ( - dq_execution_time_us / 1e6 - ) - print(f"q mem bw: {q_mem_bw_tb_s} TB/s") - print(f"dq mem bw: {dq_mem_bw_tb_s} TB/s") - - results.append( - ( - elem_dtype, - use_fp4_custom_triton_dequant_kernel, - q_execution_time_us, - q_mem_bw_tb_s, - dq_execution_time_us, - dq_mem_bw_tb_s, - ) - ) - config.use_fp4_custom_triton_dequant_kernel = False - - torch._dynamo.reset() - - print(tabulate.tabulate(results, headers=headers, floatfmt=".2f")) - - -if __name__ == "__main__": - fire.Fire(run) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 7de90daa1c..388af07874 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -146,9 +146,6 @@ class MXLinearConfig(AOBaseConfig): MXFP8Dim1CastKernelChoice.TORCH ) - # If True, uses a custom triton kernel for fp4 dequantize - use_fp4_custom_triton_dequant_kernel: bool = False - scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR def __post_init__(self): @@ -217,8 +214,6 @@ def short_str(self) -> str: s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}" s += f", kernel={self.gemm_kernel_choice.value}" s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}" - if self.use_fp4_custom_triton_dequant_kernel: - s += ", use_fp4_custom_triton_dequant_kernel=True" if self.scale_calculation_mode != ScaleCalculationMode.FLOOR: s += f", scale_calculation_mode={self.scale_calculation_mode}" return s diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 732af4df2a..cd605917af 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -30,7 +30,6 @@ from torchao.prototype.mx_formats.constants import ( E8M0_EXPONENT_BIAS, E8M0_EXPONENT_NAN_VAL, - F4_E2M1_EXP_BIAS, F6_E2M3_EXP_BIAS, F6_E3M2_EXP_BIAS, F32_EXP_BIAS, @@ -196,89 +195,6 @@ def _fp4_packed_to_bf16( output = output.to(tl.bfloat16) return output - @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_IN": 128}), - triton.Config({"BLOCK_SIZE_IN": 256}), - triton.Config({"BLOCK_SIZE_IN": 512}), - triton.Config({"BLOCK_SIZE_IN": 1024}), - triton.Config({"BLOCK_SIZE_IN": 2048}), - ], - key=["n_elements_in"], - ) - @triton.jit - def triton_f4_to_scaled_bf16_kernel( - x_ptr, - s_ptr, - output_ptr, - n_elements_in, - mx_block_size: tl.constexpr, - sign_mask_f4: tl.constexpr, - mantissa_mask_f4: tl.constexpr, - mbits_f4_e2m1: tl.constexpr, - ebits_f4_e2m1: tl.constexpr, - f4_e2m1_exp_bias: tl.constexpr, - mbits_f32: tl.constexpr, - ebits_f32: tl.constexpr, - f32_exp_bias: tl.constexpr, - zero_bits_f32: tl.constexpr, - zero_point_five_bits_f32: tl.constexpr, - e8m0_exponent_bias: tl.constexpr, - e8m0_exponent_nan_val: tl.constexpr, - BLOCK_SIZE_IN: tl.constexpr, - ): - pid = tl.program_id(axis=0) - n_elements_out = n_elements_in * 2 - n_elements_s = n_elements_out // 32 - - BLOCK_SIZE_S: tl.constexpr = BLOCK_SIZE_IN // 16 - BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2 - - block_start_in = pid * BLOCK_SIZE_IN - offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN) - mask_in = offsets_in < n_elements_in - # packed uint8 - x_packed = tl.load(x_ptr + offsets_in, mask=mask_in) - output = _fp4_packed_to_bf16( - x_packed, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - ) - - # load scale - block_start_s = pid * BLOCK_SIZE_S - offsets_s = block_start_s + tl.arange(0, BLOCK_SIZE_S) - mask_s = offsets_s < n_elements_s - s = tl.load(s_ptr + offsets_s, mask=mask_s) - - # create the scale in bf16 - s_offset = s.to(tl.int16) - e8m0_exponent_bias - s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16) - s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan")) - - # multiply output by scale - # TODO(later): see if manipulating the exponent instead of fp - # multiplication is going to give a significant speedup - output = tl.reshape(output, (BLOCK_SIZE_OUT // mx_block_size, mx_block_size)) # noqa: E501 - s_fp = tl.reshape(s_fp, (BLOCK_SIZE_S // 1, 1)) - output = output * s_fp - output = tl.reshape(output, (BLOCK_SIZE_OUT,)) - - # set up output offsets - block_start_out = pid * BLOCK_SIZE_OUT - offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT) - mask_out = offsets_out < n_elements_out - - tl.store(output_ptr + offsets_out, output, mask=mask_out) - @triton.jit def _fp6_packed_to_bf16( packed_4bits_a, @@ -575,28 +491,6 @@ def triton_pack_uint6_kernel( else: - def triton_f4_to_scaled_bf16_kernel( - x_ptr, - s_ptr, - output_ptr, - n_elements_in, - mx_block_size, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - e8m0_exponent_bias, - e8m0_exponent_nan_val, - BLOCK_SIZE_IN, - ): - raise AssertionError("unsupported without triton") - def triton_f6_to_bf16_kernel( x_ptr, output_ptr, @@ -638,47 +532,6 @@ def triton_pack_uint6_kernel( raise AssertionError("unsupported without triton") -def triton_f4_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, -): - """ - Input: a tensor of packed fp4 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. - Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) - new_shape = (*x.shape[:-1], x.shape[-1] * 2) - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda - n_elements_in = x.numel() - grid = lambda meta: ( # noqa: E731 - triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]), - ) - triton_f4_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_elements_in, - mx_block_size, - sign_mask_f4=SIGN_MASK_F4, - mantissa_mask_f4=MANTISSA_MASK_F4, - mbits_f4_e2m1=MBITS_F4_E2M1, - ebits_f4_e2m1=EBITS_F4_E2M1, - f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS, - mbits_f32=MBITS_F32, - ebits_f32=EBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - zero_bits_f32=ZERO_BITS_F32, - zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output - - def triton_f6_e2m3_to_bf16(x: torch.Tensor) -> torch.Tensor: """ Input: a tensor of packed fp6 values diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 1a033a1096..161fcd6064 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -65,7 +65,6 @@ def _to_mxfp8_dim1_kernel_wrapper( elem_dtype, block_size, hp_dtype, - False, gemm_kernel_choice, False, None, @@ -85,7 +84,6 @@ def _to_mxfp8_dim1_kernel_wrapper( elem_dtype, block_size, hp_dtype, - False, gemm_kernel_choice, False, None, diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index bd4efd379b..07e47eed66 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -95,7 +95,6 @@ def _addmm_mx_dispatch( k.elem_dtype, k.block_size, k.scaling_mode, - k.use_fp4_custom_triton_dequant_kernel, k.gemm_kernel_choice, k.pack_fp6, ) @@ -186,7 +185,6 @@ def mx_t(func, types, args, kwargs): old._elem_dtype, old._block_size, old._orig_dtype, - old._use_fp4_custom_triton_dequant_kernel, old._gemm_kernel_choice, old._pack_fp6, old.act_quant_kwargs, @@ -231,7 +229,6 @@ def mx_view_op(func, types, args, kwargs): args[0]._elem_dtype, args[0]._block_size, args[0]._orig_dtype, - args[0]._use_fp4_custom_triton_dequant_kernel, args[0]._gemm_kernel_choice, args[0]._pack_fp6, args[0].act_quant_kwargs, @@ -293,7 +290,6 @@ def mx_slice(func, types, args, kwargs): x._elem_dtype, x._block_size, x._orig_dtype, - x._use_fp4_custom_triton_dequant_kernel, x._gemm_kernel_choice, x._pack_fp6, x.act_quant_kwargs, @@ -348,7 +344,6 @@ def autocast_to_copy(func, types, args, kwargs): tensor._elem_dtype, tensor._block_size, dtype, - tensor._use_fp4_custom_triton_dequant_kernel, tensor._gemm_kernel_choice, tensor._pack_fp6, tensor.act_quant_kwargs, diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 533e186acd..273f1b2b56 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -53,7 +53,6 @@ f32_to_f6_e3m2_unpacked, pack_uint4, pack_uint6, - triton_f4_to_scaled_bf16, triton_f6_e2m3_to_scaled_bf16, triton_f6_e3m2_to_scaled_bf16, unpack_uint4, @@ -77,7 +76,6 @@ class QuantizeTensorToMXKwargs(QuantizeTensorKwargs): elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn block_size: int = 32 scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR - use_fp4_custom_triton_dequant_kernel: bool = False gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED pack_fp6: bool = False @@ -349,7 +347,6 @@ def to_dtype( elem_dtype, block_size, target_dtype, - use_fp4_custom_triton_dequant_kernel, pack_fp6, ): orig_shape = data_lp.shape @@ -392,25 +389,15 @@ def to_dtype( data_hp = f6_e3m2_unpacked_to_f32(data_lp) data_hp = data_hp.to(target_dtype).reshape(orig_shape) elif elem_dtype == torch.float4_e2m1fn_x2: - if use_fp4_custom_triton_dequant_kernel: - data_hp_rescaled = triton_f4_to_scaled_bf16( - data_lp, - scale_e8m0, - block_size, - ) - if is_transposed: - data_hp_rescaled = data_hp_rescaled.t() - return data_hp_rescaled.to(target_dtype) - else: - # fp4 - f4_unpacked = unpack_uint4(data_lp) - # for now we only have a cast to f32 - # TODO(future PR): add cast directly to bf16 - f32 = f4_unpacked_to_f32(f4_unpacked) - data_hp = f32.to(target_dtype) - # manually adjust shape to account for the unpacking - # TODO(future PR): clean up the shape code and remove the hack - # below + # fp4 + f4_unpacked = unpack_uint4(data_lp) + # for now we only have a cast to f32 + # TODO(future PR): add cast directly to bf16 + f32 = f4_unpacked_to_f32(f4_unpacked) + data_hp = f32.to(target_dtype) + # manually adjust shape to account for the unpacking + # TODO(future PR): clean up the shape code and remove the hack + # below orig_shape = (*orig_shape[:-1], orig_shape[-1] * 2) else: raise AssertionError("unsupported") @@ -469,7 +456,6 @@ class MXTensor(TorchAOBaseTensor): "_elem_dtype", "_block_size", "_orig_dtype", - "_use_fp4_custom_triton_dequant_kernel", "_gemm_kernel_choice", "_pack_fp6", "act_quant_kwargs", @@ -482,7 +468,6 @@ def __new__( elem_dtype, block_size, orig_dtype, - use_fp4_custom_triton_dequant_kernel, gemm_kernel_choice, pack_fp6, act_quant_kwargs, @@ -551,9 +536,6 @@ def __new__( self._elem_dtype = elem_dtype self._block_size = block_size self._orig_dtype = orig_dtype - self._use_fp4_custom_triton_dequant_kernel = ( - use_fp4_custom_triton_dequant_kernel - ) self._gemm_kernel_choice = gemm_kernel_choice self._pack_fp6 = pack_fp6 self.act_quant_kwargs = act_quant_kwargs @@ -587,7 +569,6 @@ def to_dtype(self, target_dtype): self._elem_dtype, self._block_size, target_dtype, - self._use_fp4_custom_triton_dequant_kernel, self._pack_fp6, ) @@ -598,7 +579,6 @@ def to_mx( elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, - use_fp4_custom_triton_dequant_kernel: bool = False, # TODO(future PR): switch default gemm to cublas gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, pack_fp6: bool = False, @@ -617,7 +597,6 @@ def to_mx( elem_dtype, block_size, data_hp.dtype, - use_fp4_custom_triton_dequant_kernel, gemm_kernel_choice, pack_fp6, act_quant_kwargs, @@ -636,7 +615,6 @@ def to_mx( elem_dtype, block_size, data_hp.dtype, - use_fp4_custom_triton_dequant_kernel, gemm_kernel_choice, pack_fp6, act_quant_kwargs, From 8b69d20368e04894d50ac8626c36a858dd5a8eda Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 21 Aug 2025 07:50:27 -0700 Subject: [PATCH 3/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_kernels.py | 35 ++++++++++++++++++ torchao/prototype/mx_formats/kernels.py | 43 +++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index 0957bf0fb9..3d32973f40 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -561,3 +561,38 @@ def test_cuda_mx_dim1_invalid_block_size(): scale_dim_x=1, scale_dim_y=invalid_block_size, ) + + +def _fp32_to_fp4_reference( + data_hp: torch.Tensor, +) -> torch.Tensor: + # works + data_hp = data_hp.float() + data_lp = f32_to_f4_unpacked(data_hp) + + # does not work + # data_lp = f32_to_f4_unpacked(data_hp.float()) + + data_lp = pack_uint4(data_lp) + return data_lp + + +# TODO add skips +def test_fp32_cast_to_fp4x2(): + from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2 + + M, K = 16, 16 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + # make x's range be the representable range of fp4 + x = x * 6.0 + + data_ref = _fp32_to_fp4_reference(x) + # print(2, x[0]) + data = triton_fp32_cast_to_fp4x2(x) + # print(3, x[0]) + # print(0, x) + # print(1, data_ref, data_ref.shape) + # print(2, data, data.shape) + torch.testing.assert_close(data_ref, data) + assert data.shape == (M, K // 2) + print("done") diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index cd605917af..4271763e58 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1454,6 +1454,49 @@ def _(scale_tensor): padded_cols = n_col_blocks * 4 return scale_tensor.new_empty((padded_rows, padded_cols)) + + @triton.jit + def fp32_cast_to_fp4x2_triton_kernel( + x_ptr, + q_ptr, + stride_xm, + stride_xn, + M, + N, + ): + pid_m = tl.program_id(1) + pid_n = tl.program_id(0) + + offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] + offs_n = pid_n * 64 + tl.arange(0, 64)[None, :] + mask = None + other = None + x = tl.load( + x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other + ) # [128, 64] + x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16] + + # Convert to FP4 + x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split()) + offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] + offs_n = pid_n * 32 + tl.arange(0, 32)[None, :] + tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=None) + + def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization" + xq = x.new_empty(M, N // 2, dtype=torch.uint8) + grid = (triton.cdiv(N, 64), triton.cdiv(M, 128)) + fp32_cast_to_fp4x2_triton_kernel[grid]( + x, + xq, + x.stride(0), + x.stride(1), + M, + N, + ) + + return xq.view(torch.uint8) else: def triton_to_mxfp8_dim1( From 12657713703e860d567c3a08eb026a21b8afd97a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 21 Aug 2025 07:52:36 -0700 Subject: [PATCH 4/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_kernels.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index 3d32973f40..b1d2d5291b 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -586,6 +586,13 @@ def test_fp32_cast_to_fp4x2(): # make x's range be the representable range of fp4 x = x * 6.0 + # this leads to values in `x` being overridden inplace + # TODO fix it + print(0, x) + data = triton_fp32_cast_to_fp4x2(x) + print(1, x) + return + data_ref = _fp32_to_fp4_reference(x) # print(2, x[0]) data = triton_fp32_cast_to_fp4x2(x) From ce7a34555c31cf20b195b1be25d26b454820c199 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 21 Aug 2025 07:59:28 -0700 Subject: [PATCH 5/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_kernels.py | 26 +++++------------------ torchao/prototype/mx_formats/kernels.py | 3 ++- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index b1d2d5291b..42c1c4d1de 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -566,18 +566,15 @@ def test_cuda_mx_dim1_invalid_block_size(): def _fp32_to_fp4_reference( data_hp: torch.Tensor, ) -> torch.Tensor: - # works - data_hp = data_hp.float() - data_lp = f32_to_f4_unpacked(data_hp) - - # does not work - # data_lp = f32_to_f4_unpacked(data_hp.float()) - + data_lp = f32_to_f4_unpacked(data_hp.float()) data_lp = pack_uint4(data_lp) return data_lp -# TODO add skips +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="requires CUDA capability 10.0 or greater", +) def test_fp32_cast_to_fp4x2(): from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2 @@ -586,20 +583,7 @@ def test_fp32_cast_to_fp4x2(): # make x's range be the representable range of fp4 x = x * 6.0 - # this leads to values in `x` being overridden inplace - # TODO fix it - print(0, x) - data = triton_fp32_cast_to_fp4x2(x) - print(1, x) - return - data_ref = _fp32_to_fp4_reference(x) - # print(2, x[0]) data = triton_fp32_cast_to_fp4x2(x) - # print(3, x[0]) - # print(0, x) - # print(1, data_ref, data_ref.shape) - # print(2, data, data.shape) torch.testing.assert_close(data_ref, data) assert data.shape == (M, K // 2) - print("done") diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 4271763e58..091f0e29d9 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1480,7 +1480,8 @@ def fp32_cast_to_fp4x2_triton_kernel( x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split()) offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] offs_n = pid_n * 32 + tl.arange(0, 32)[None, :] - tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=None) + mask = (offs_m < M) & (offs_n < N // 2) + tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask) def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: M, N = x.shape From ec9f618954f333082260b7116fcbc9f5e468984b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 21 Aug 2025 08:04:38 -0700 Subject: [PATCH 6/8] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/kernels.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 091f0e29d9..c91ee40f27 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1466,7 +1466,6 @@ def fp32_cast_to_fp4x2_triton_kernel( ): pid_m = tl.program_id(1) pid_n = tl.program_id(0) - offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] offs_n = pid_n * 64 + tl.arange(0, 64)[None, :] mask = None @@ -1475,7 +1474,6 @@ def fp32_cast_to_fp4x2_triton_kernel( x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other ) # [128, 64] x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16] - # Convert to FP4 x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split()) offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] @@ -1483,9 +1481,18 @@ def fp32_cast_to_fp4x2_triton_kernel( mask = (offs_m < M) & (offs_n < N // 2) tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask) - def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor: + """ + Input: a float32 tensor with shape (M, N) + Output: a uint8 tensor with shape (M, N // 2), with the values being the result + of casting each original value to fp4_e2m1, and then packing fp4x2 + + TODO(future PR): optimize performance + TODO(future PR): better checks for shapes, etc + TODO(future PR): integrate into training/inference + TODO(future PR): integrate with compile, ideally allowing fusion + """ M, N = x.shape - assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization" xq = x.new_empty(M, N // 2, dtype=torch.uint8) grid = (triton.cdiv(N, 64), triton.cdiv(M, 128)) fp32_cast_to_fp4x2_triton_kernel[grid]( @@ -1496,7 +1503,6 @@ def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens M, N, ) - return xq.view(torch.uint8) else: From 3c5fe9ec28ecc58568d99c83f8e018aad80027dd Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 21 Aug 2025 08:05:40 -0700 Subject: [PATCH 7/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index 42c1c4d1de..a2a8318a81 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -585,5 +585,5 @@ def test_fp32_cast_to_fp4x2(): data_ref = _fp32_to_fp4_reference(x) data = triton_fp32_cast_to_fp4x2(x) - torch.testing.assert_close(data_ref, data) + torch.testing.assert_close(data_ref, data, atol=0, rtol=0) assert data.shape == (M, K // 2) From a0f764c08b7a120793ecea5e571e7503d3752409 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 22 Aug 2025 07:07:18 -0700 Subject: [PATCH 8/8] Update [ghstack-poisoned] --- benchmarks/mx_formats/cast_bench.py | 44 ++++++++++++++++++++- test/prototype/mx_formats/test_mx_tensor.py | 16 ++++++++ torchao/prototype/mx_formats/kernels.py | 6 ++- torchao/prototype/mx_formats/mx_tensor.py | 28 +++++++++---- 4 files changed, 84 insertions(+), 10 deletions(-) diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index a9d8b18ae7..7f6277831a 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -59,8 +59,15 @@ def to_mx_dim0_reference( block_size, scaling_mode=ScaleCalculationMode.FLOOR, target_dtype=torch.float8_e4m3fn, + use_fp32_to_fp4_triton_kernel=False, ): - scale_d0, data_d0 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode) + scale_d0, data_d0 = to_mx( + x_hp, + target_dtype, + block_size, + scaling_mode=scaling_mode, + use_fp32_to_fp4_triton_kernel=use_fp32_to_fp4_triton_kernel, + ) return data_d0, scale_d0 @@ -96,6 +103,7 @@ def run( "dim0_dim1", "dim0_mxfp8_floor", "dim0_mxfp4_floor", + "dim0_mxfp4_triton_floor", "dim0_mxfp8_rceil", "dim1_mxfp8_floor", "dim1_mxfp8_rceil", @@ -204,6 +212,40 @@ def run( bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) + elif mode == "dim0_mxfp4_triton_floor": + to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference) + y_d0, s_d0 = to_mx_dim0_reference_c( + x, + BLOCK_SIZE, + target_dtype=torch.float4_e2m1fn_x2, + use_fp32_to_fp4_triton_kernel=True, + ) + + for _ in range(2): + __ = to_mx_dim0_reference_c( + x, + BLOCK_SIZE, + target_dtype=torch.float4_e2m1fn_x2, + use_fp32_to_fp4_triton_kernel=True, + ) + time_us = benchmark_cuda_function_in_microseconds( + lambda x, b: to_mx_dim0_reference_c( + x, + BLOCK_SIZE, + target_dtype=torch.float4_e2m1fn_x2, + use_fp32_to_fp4_triton_kernel=True, + ), + x, + BLOCK_SIZE, + ) + + # TODO(future PR): make to_mx return float4 directly + assert y_d0.dtype == torch.uint8 + assert s_d0.dtype == torch.float8_e8m0fnu + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + elif mode == "dim0_mxfp8_rceil": to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference) y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index ea1b7c6459..e235872ebc 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -96,6 +96,22 @@ def test_realistic_numerics(elem_dtype, scale_calculation_mode): _test_mx(data, elem_dtype, block_size, scale_calculation_mode) +def test_fp4_triton_cast_does_not_change_numerics(): + # TODO(before land): proper skips + # TODO(before land): test rank 3 + data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + data_mx_ref = MXTensor.to_mx( + data, torch.float4_e2m1fn_x2, 32, use_fp32_to_fp4_triton_kernel=False + ) + data_mx = MXTensor.to_mx( + data, torch.float4_e2m1fn_x2, 32, use_fp32_to_fp4_triton_kernel=True + ) + torch.testing.assert_close(data_mx_ref.qdata, data_mx.qdata, atol=0, rtol=0) + torch.testing.assert_close( + data_mx_ref._scale_e8m0, data_mx._scale_e8m0, atol=0, rtol=0 + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_all_zeros(elem_dtype): diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index c91ee40f27..b3ddbe1596 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1487,7 +1487,8 @@ def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor: Output: a uint8 tensor with shape (M, N // 2), with the values being the result of casting each original value to fp4_e2m1, and then packing fp4x2 - TODO(future PR): optimize performance + TODO(future PR): optimize performance, lowest hanging fruit is we want + to add an e8m0 scale and scale the incoming tensor inside of this kernel TODO(future PR): better checks for shapes, etc TODO(future PR): integrate into training/inference TODO(future PR): integrate with compile, ideally allowing fusion @@ -1525,6 +1526,9 @@ def triton_quantize_nvfp4( ) -> Tuple[torch.Tensor, torch.Tensor]: raise AssertionError("needs torch version 2.8+ and triton") + def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor: + raise AssertionError("needs torch version 2.8+ and triton") + # MXFP8 CUDA kernel is only built on SM100+ if is_sm_at_least_100(): diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 273f1b2b56..de3d947a6e 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -55,6 +55,7 @@ pack_uint6, triton_f6_e2m3_to_scaled_bf16, triton_f6_e3m2_to_scaled_bf16, + triton_fp32_cast_to_fp4x2, unpack_uint4, ) from torchao.quantization.quantize_.common import ( @@ -134,6 +135,7 @@ def to_mx( block_size: int, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, pack_fp6: bool = False, + use_fp32_to_fp4_triton_kernel: bool = False, ): """ Takes a high precision tensor and converts to MX scale and raw data, in @@ -309,13 +311,17 @@ def to_mx( # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == torch.float4_e2m1fn_x2: - # can't reshape at the end without handling it in the packing code, - # punt until later since we'll need to rethink the torch.compile - # approach for fp4x2 in any case - data_lp = data_lp.reshape(orig_shape) - data_lp = f32_to_f4_unpacked(data_lp) - orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2] - data_lp = pack_uint4(data_lp) + if use_fp32_to_fp4_triton_kernel: + data_lp = data_lp.reshape(orig_shape) + data_lp = triton_fp32_cast_to_fp4x2(data_lp) + else: + # can't reshape at the end without handling it in the packing code, + # punt until later since we'll need to rethink the torch.compile + # approach for fp4x2 in any case + data_lp = data_lp.reshape(orig_shape) + data_lp = f32_to_f4_unpacked(data_lp) + orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2] + data_lp = pack_uint4(data_lp) else: raise AssertionError("unsupported") @@ -583,9 +589,15 @@ def to_mx( gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, pack_fp6: bool = False, act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None, + use_fp32_to_fp4_triton_kernel: bool = False, ): scale_e8m0_biased, data_lp = to_mx( - data_hp, elem_dtype, block_size, scaling_mode, pack_fp6 + data_hp, + elem_dtype, + block_size, + scaling_mode, + pack_fp6, + use_fp32_to_fp4_triton_kernel, ) if isinstance(scale_e8m0_biased, DTensor): assert isinstance(data_lp, DTensor), "unsupported"