Skip to content

Commit 1fb0c25

Browse files
committed
mx: delete triton_f4_to_bf16 kernel
Summary: This is dead code from the early days and its test is broken, so it is being deleted.

Test Plan:
```bash
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: 3460e23
ghstack-comment-id: 3210590444
Pull-Request: #2830
1 parent b6435f9 commit 1fb0c25

File tree

2 files changed

+0
-114
lines changed

2 files changed

+0
-114
lines changed

test/prototype/mx_formats/test_kernels.py

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@
3535
get_bits,
3636
pack_uint4,
3737
pack_uint6,
38-
triton_f4_to_bf16,
3938
triton_f6_e2m3_to_bf16,
4039
triton_f6_e3m2_to_bf16,
4140
triton_to_mxfp8_dim1,
@@ -327,17 +326,6 @@ def test_fp4_pack_unpack():
327326
assert torch.all(orig_vals_dq == orig_vals)
328327

329328

330-
# TODO(future PR): fix or delete this test
331-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
332-
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
333-
@pytest.mark.skipif(is_sm_at_least_89(), reason="broken on CUDA capability 8.9+")
334-
def test_fp4_triton_unscaled_cast():
335-
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
336-
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
337-
f32_triton = triton_f4_to_bf16(packed_vals).to(torch.float)
338-
assert torch.all(torch.eq(f32_ref, f32_triton))
339-
340-
341329
# TODO(future PR): fix or delete this test
342330
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
343331
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")

torchao/prototype/mx_formats/kernels.py

Lines changed: 0 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -196,55 +196,6 @@ def _fp4_packed_to_bf16(
196196
output = output.to(tl.bfloat16)
197197
return output
198198

199-
@triton.jit
200-
def triton_f4_to_bf16_kernel(
201-
x_ptr,
202-
output_ptr,
203-
n_elements_in,
204-
sign_mask_f4: tl.constexpr,
205-
mantissa_mask_f4: tl.constexpr,
206-
mbits_f4_e2m1: tl.constexpr,
207-
ebits_f4_e2m1: tl.constexpr,
208-
f4_e2m1_exp_bias: tl.constexpr,
209-
mbits_f32: tl.constexpr,
210-
ebits_f32: tl.constexpr,
211-
f32_exp_bias: tl.constexpr,
212-
zero_bits_f32: tl.constexpr,
213-
zero_point_five_bits_f32: tl.constexpr,
214-
BLOCK_SIZE_IN: tl.constexpr,
215-
):
216-
pid = tl.program_id(axis=0)
217-
n_elements_out = n_elements_in * 2
218-
BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2
219-
220-
block_start_in = pid * BLOCK_SIZE_IN
221-
offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN)
222-
223-
mask_in = offsets_in < n_elements_in
224-
225-
# packed uint8
226-
x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
227-
output = _fp4_packed_to_bf16(
228-
x_packed,
229-
sign_mask_f4,
230-
mantissa_mask_f4,
231-
mbits_f4_e2m1,
232-
ebits_f4_e2m1,
233-
f4_e2m1_exp_bias,
234-
mbits_f32,
235-
ebits_f32,
236-
f32_exp_bias,
237-
zero_bits_f32,
238-
zero_point_five_bits_f32,
239-
)
240-
241-
# set up output offsets
242-
block_start_out = pid * BLOCK_SIZE_OUT
243-
offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT)
244-
mask_out = offsets_out < n_elements_out
245-
246-
tl.store(output_ptr + offsets_out, output, mask=mask_out)
247-
248199
@triton.autotune(
249200
configs=[
250201
triton.Config({"BLOCK_SIZE_IN": 128}),
@@ -624,24 +575,6 @@ def triton_pack_uint6_kernel(
624575

625576
else:
626577

627-
def triton_f4_to_bf16_kernel(
628-
x_ptr,
629-
output_ptr,
630-
n_elements_in,
631-
sign_mask_f4,
632-
mantissa_mask_f4,
633-
mbits_f4_e2m1,
634-
ebits_f4_e2m1,
635-
f4_e2m1_exp_bias,
636-
mbits_f32,
637-
ebits_f32,
638-
f32_exp_bias,
639-
zero_bits_f32,
640-
zero_point_five_bits_f32,
641-
BLOCK_SIZE_IN,
642-
):
643-
raise AssertionError("unsupported without triton")
644-
645578
def triton_f4_to_scaled_bf16_kernel(
646579
x_ptr,
647580
s_ptr,
@@ -705,41 +638,6 @@ def triton_pack_uint6_kernel(
705638
raise AssertionError("unsupported without triton")
706639

707640

708-
def triton_f4_to_bf16(x: torch.Tensor):
709-
"""
710-
Input: a tensor of packed fp4 values
711-
Output: a tensor of bfloat16 values
712-
713-
Note: this function is only used in testing, so we can test
714-
the numerical correctness of the cast without the scaling.
715-
"""
716-
new_shape = (*x.shape[:-1], x.shape[-1] * 2)
717-
output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
718-
assert x.is_contiguous()
719-
assert x.is_cuda and output.is_cuda
720-
n_elements_in = x.numel()
721-
grid = lambda meta: ( # noqa: E731
722-
triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
723-
) # noqa: E731,E501
724-
triton_f4_to_bf16_kernel[grid](
725-
x,
726-
output,
727-
n_elements_in,
728-
sign_mask_f4=SIGN_MASK_F4,
729-
mantissa_mask_f4=MANTISSA_MASK_F4,
730-
mbits_f4_e2m1=MBITS_F4_E2M1,
731-
ebits_f4_e2m1=EBITS_F4_E2M1,
732-
f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
733-
mbits_f32=MBITS_F32,
734-
ebits_f32=EBITS_F32,
735-
f32_exp_bias=F32_EXP_BIAS,
736-
zero_bits_f32=ZERO_BITS_F32,
737-
zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
738-
BLOCK_SIZE_IN=512,
739-
)
740-
return output
741-
742-
743641
def triton_f4_to_scaled_bf16(
744642
x: torch.Tensor,
745643
s_e8m0: torch.Tensor,

0 commit comments

Comments
 (0)