@@ -196,55 +196,6 @@ def _fp4_packed_to_bf16(
196
196
output = output .to (tl .bfloat16 )
197
197
return output
198
198
199
- @triton .jit
200
- def triton_f4_to_bf16_kernel (
201
- x_ptr ,
202
- output_ptr ,
203
- n_elements_in ,
204
- sign_mask_f4 : tl .constexpr ,
205
- mantissa_mask_f4 : tl .constexpr ,
206
- mbits_f4_e2m1 : tl .constexpr ,
207
- ebits_f4_e2m1 : tl .constexpr ,
208
- f4_e2m1_exp_bias : tl .constexpr ,
209
- mbits_f32 : tl .constexpr ,
210
- ebits_f32 : tl .constexpr ,
211
- f32_exp_bias : tl .constexpr ,
212
- zero_bits_f32 : tl .constexpr ,
213
- zero_point_five_bits_f32 : tl .constexpr ,
214
- BLOCK_SIZE_IN : tl .constexpr ,
215
- ):
216
- pid = tl .program_id (axis = 0 )
217
- n_elements_out = n_elements_in * 2
218
- BLOCK_SIZE_OUT : tl .constexpr = BLOCK_SIZE_IN * 2
219
-
220
- block_start_in = pid * BLOCK_SIZE_IN
221
- offsets_in = block_start_in + tl .arange (0 , BLOCK_SIZE_IN )
222
-
223
- mask_in = offsets_in < n_elements_in
224
-
225
- # packed uint8
226
- x_packed = tl .load (x_ptr + offsets_in , mask = mask_in )
227
- output = _fp4_packed_to_bf16 (
228
- x_packed ,
229
- sign_mask_f4 ,
230
- mantissa_mask_f4 ,
231
- mbits_f4_e2m1 ,
232
- ebits_f4_e2m1 ,
233
- f4_e2m1_exp_bias ,
234
- mbits_f32 ,
235
- ebits_f32 ,
236
- f32_exp_bias ,
237
- zero_bits_f32 ,
238
- zero_point_five_bits_f32 ,
239
- )
240
-
241
- # set up output offsets
242
- block_start_out = pid * BLOCK_SIZE_OUT
243
- offsets_out = block_start_out + tl .arange (0 , BLOCK_SIZE_OUT )
244
- mask_out = offsets_out < n_elements_out
245
-
246
- tl .store (output_ptr + offsets_out , output , mask = mask_out )
247
-
248
199
@triton .autotune (
249
200
configs = [
250
201
triton .Config ({"BLOCK_SIZE_IN" : 128 }),
@@ -624,24 +575,6 @@ def triton_pack_uint6_kernel(
624
575
625
576
else :
626
577
627
- def triton_f4_to_bf16_kernel (
628
- x_ptr ,
629
- output_ptr ,
630
- n_elements_in ,
631
- sign_mask_f4 ,
632
- mantissa_mask_f4 ,
633
- mbits_f4_e2m1 ,
634
- ebits_f4_e2m1 ,
635
- f4_e2m1_exp_bias ,
636
- mbits_f32 ,
637
- ebits_f32 ,
638
- f32_exp_bias ,
639
- zero_bits_f32 ,
640
- zero_point_five_bits_f32 ,
641
- BLOCK_SIZE_IN ,
642
- ):
643
- raise AssertionError ("unsupported without triton" )
644
-
645
578
def triton_f4_to_scaled_bf16_kernel (
646
579
x_ptr ,
647
580
s_ptr ,
@@ -705,41 +638,6 @@ def triton_pack_uint6_kernel(
705
638
raise AssertionError ("unsupported without triton" )
706
639
707
640
708
- def triton_f4_to_bf16 (x : torch .Tensor ):
709
- """
710
- Input: a tensor of packed fp4 values
711
- Output: a tensor of bfloat16 values
712
-
713
- Note: this function is only used in testing, so we can test
714
- the numerical correctness of the cast without the scaling.
715
- """
716
- new_shape = (* x .shape [:- 1 ], x .shape [- 1 ] * 2 )
717
- output = torch .empty (* new_shape , device = x .device , dtype = torch .bfloat16 )
718
- assert x .is_contiguous ()
719
- assert x .is_cuda and output .is_cuda
720
- n_elements_in = x .numel ()
721
- grid = lambda meta : ( # noqa: E731
722
- triton .cdiv (n_elements_in , meta ["BLOCK_SIZE_IN" ]),
723
- ) # noqa: E731,E501
724
- triton_f4_to_bf16_kernel [grid ](
725
- x ,
726
- output ,
727
- n_elements_in ,
728
- sign_mask_f4 = SIGN_MASK_F4 ,
729
- mantissa_mask_f4 = MANTISSA_MASK_F4 ,
730
- mbits_f4_e2m1 = MBITS_F4_E2M1 ,
731
- ebits_f4_e2m1 = EBITS_F4_E2M1 ,
732
- f4_e2m1_exp_bias = F4_E2M1_EXP_BIAS ,
733
- mbits_f32 = MBITS_F32 ,
734
- ebits_f32 = EBITS_F32 ,
735
- f32_exp_bias = F32_EXP_BIAS ,
736
- zero_bits_f32 = ZERO_BITS_F32 ,
737
- zero_point_five_bits_f32 = ZERO_POINT_FIVE_BITS_F32 ,
738
- BLOCK_SIZE_IN = 512 ,
739
- )
740
- return output
741
-
742
-
743
641
def triton_f4_to_scaled_bf16 (
744
642
x : torch .Tensor ,
745
643
s_e8m0 : torch .Tensor ,
0 commit comments