@@ -1454,6 +1454,57 @@ def _(scale_tensor):
         padded_cols = n_col_blocks * 4
 
         return scale_tensor.new_empty((padded_rows, padded_cols))
+
+    @triton.jit
+    def fp32_cast_to_fp4x2_triton_kernel(
+        x_ptr,
+        q_ptr,
+        stride_xm,
+        stride_xn,
+        M,
+        N,
+    ):
+        pid_m = tl.program_id(1)
+        pid_n = tl.program_id(0)
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+        mask = None
+        other = None
+        x = tl.load(
+            x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
+        )  # [128, 64]
+        x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+        # Convert to FP4
+        x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+        mask = (offs_m < M) & (offs_n < N // 2)
+        tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
+
+    def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
+        """
+        Input: a float32 tensor with shape (M, N)
+        Output: a uint8 tensor with shape (M, N // 2), with the values being the result
+        of casting each original value to fp4_e2m1, and then packing fp4x2
+
+        TODO(future PR): optimize performance, lowest hanging fruit is we want
+        to add an e8m0 scale and scale the incoming tensor inside of this kernel
+        TODO(future PR): better checks for shapes, etc
+        TODO(future PR): integrate into training/inference
+        TODO(future PR): integrate with compile, ideally allowing fusion
+        """
+        M, N = x.shape
+        xq = x.new_empty(M, N // 2, dtype=torch.uint8)
+        grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
+        fp32_cast_to_fp4x2_triton_kernel[grid](
+            x,
+            xq,
+            x.stride(0),
+            x.stride(1),
+            M,
+            N,
+        )
+        return xq.view(torch.uint8)
 else:
 
     def triton_to_mxfp8_dim1(
@@ -1475,6 +1526,9 @@ def triton_quantize_nvfp4(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
 
+    def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
 
 # MXFP8 CUDA kernel is only built on SM100+
 if is_sm_at_least_100():
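
A minimal usage sketch of the new wrapper, for reference (not part of the diff above).
The import path is assumed from the surrounding torchao.prototype.mx_formats.kernels
module; a CUDA device and the torch 2.8+/triton requirement are needed, and the shapes
are kept multiples of the kernel's 128x64 tile since the input load is unmasked.

import torch

from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2

# fp4_e2m1 can only represent +/- {0, 0.5, 1, 1.5, 2, 3, 4, 6}, so the cast is lossy.
x = torch.randn(256, 128, dtype=torch.float32, device="cuda")
xq = triton_fp32_cast_to_fp4x2(x)

assert xq.dtype == torch.uint8
assert xq.shape == (256, 64)  # two fp4 values packed per output byte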