 import torch.nn as nn
 from torch.nn import Parameter
 
+import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import (
     FP8GlobalStateManager,
     fp8_autocast,
     fp8_model_init,
 )
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -1697,13 +1699,16 @@ def _test_grouped_linear_accuracy(
     fp8,
     fuse_wgrad_accumulation,
     delay_wgrad_compute=False,
+    activation_func=None,  # assumed to be a gated activation (e.g. SwiGLU)
 ):
     reset_rng_states()
     if fp8:
         FP8GlobalStateManager.reset()
 
+    # A gated activation halves the last dimension, so feed it 2 * hidden_size.
+    hidden_size = config.hidden_size if activation_func is None else 2 * config.hidden_size
     inp_hidden_states = torch.randn(
-        (config.max_seqlen_q, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
@@ -1728,11 +1733,11 @@ def _test_grouped_linear_accuracy(
     with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
         if isinstance(block, GroupedLinear):
             m_splits = m_splits * bs
-            out = block(inp_hidden_states, m_splits.tolist())
+            out = block(activation_func(inp_hidden_states) if activation_func else inp_hidden_states, m_splits.tolist())
         else:
             out = torch.cat(
                 [
-                    block[i](inp)
+                    block[i](activation_func(inp) if activation_func else inp)
                     for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                 ]
             )
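
For reference on why the helper above doubles the input width: a gated activation such as `te.ops.SwiGLU` consumes gate/value pairs and halves the last dimension, so an input of width `2 * hidden_size` yields an activation of width `hidden_size`, matching `GroupedLinear`'s `in_features`. A minimal sketch of that shape relationship (sizes are illustrative only):

```python
import torch
import transformer_engine.pytorch as te

hidden_size = 128  # illustrative size
swiglu = te.ops.SwiGLU()

x = torch.randn(32, 2 * hidden_size, dtype=torch.bfloat16, device="cuda")
y = swiglu(x)

# The gated activation halves the feature dimension: (32, 256) -> (32, 128).
assert y.shape == (32, hidden_size)
```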
@@ -1967,6 +1972,92 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
     )
 
 
+@pytest.mark.skipif(not mxfp8_available, reason="MXFP8 is not available")
+@pytest.mark.parametrize("dtype", param_types, ids=str)
+@pytest.mark.parametrize("num_gemms", [3, 6])
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", ["126m"])
+@pytest.mark.parametrize("recipe", [recipe.MXFP8BlockScaling()])
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_grouped_linear_fp8_input(
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    recipe,
+    fp8_model_params,
+):
+    config = model_configs[model]
+    if config.max_seqlen_q % 32 != 0:
+        pytest.skip("MXFP8 requires sequence length to be divisible by 32.")
+
+    with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
+        grouped_linear_bf16_input = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            fuse_wgrad_accumulation=True,
+        ).eval()
+
+        grouped_linear_fp8_input = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            fuse_wgrad_accumulation=True,
+        ).eval()
+
+    # Share params between the two modules
+    with torch.no_grad():
+        for i in range(num_gemms):
+            setattr(
+                grouped_linear_fp8_input,
+                f"weight{i}",
+                Parameter(getattr(grouped_linear_bf16_input, f"weight{i}").clone()),
+            )
+            weight_i = getattr(grouped_linear_bf16_input, f"weight{i}")
+            weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
+            weight_i_copy = getattr(grouped_linear_fp8_input, f"weight{i}")
+            weight_i_copy.main_grad = weight_i.main_grad.clone()
+
+    bf16_activation = te.ops.SwiGLU()
+    fp8_activation = te.ops.Sequential(
+        te.ops.SwiGLU(),
+        te.ops.Quantize(forward=True, backward=False),  # Output a QuantizedTensor in forward
+    )
+
+    outputs_ref = _test_grouped_linear_accuracy(
+        grouped_linear_bf16_input,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        recipe,
+        fp8=True,
+        fuse_wgrad_accumulation=True,
+        activation_func=bf16_activation,
+    )
+    outputs = _test_grouped_linear_accuracy(
+        grouped_linear_fp8_input,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        recipe,
+        fp8=True,
+        fuse_wgrad_accumulation=True,
+        activation_func=fp8_activation,
+    )
+    # Should be a bit-wise match
+    for o, o_ref in zip(outputs, outputs_ref):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+
 def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
 
     def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
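
The new test above boils down to the following usage pattern: the activation chain ends in a forward-only `Quantize`, so `GroupedLinear` receives an already-quantized MXFP8 tensor instead of quantizing the activation itself, and the result should match the bf16-input path bit for bit. A minimal sketch under the same assumptions as the test (the `te.ops` API used above, sizes divisible by 32; names and sizes here are illustrative):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch import GroupedLinear

num_gemms, hidden, tokens = 3, 128, 64  # illustrative sizes, divisible by 32 for MXFP8

linear = GroupedLinear(num_gemms, hidden, 4 * hidden, bias=False, device="cuda")

# SwiGLU followed by a forward-only quantize: the Sequential emits a QuantizedTensor,
# so GroupedLinear can consume the MXFP8 data directly instead of re-quantizing it.
act = te.ops.Sequential(
    te.ops.SwiGLU(),
    te.ops.Quantize(forward=True, backward=False),
)

x = torch.randn(tokens * num_gemms, 2 * hidden, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
    out = linear(act(x), [tokens] * num_gemms)
```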
@@ -2706,3 +2797,49 @@ def _run_module(m, inp):
     out = _run_module(g2, b)
 
     assert_allclose(out, outT, 1e-7)
+
+
+@pytest.mark.skipif(not mxfp8_available, reason="MXFP8 is not available")
+@pytest.mark.parametrize("dtype", param_types, ids=str)
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("m", [64, 128, 256])
+@pytest.mark.parametrize("k", [64, 128, 256])
+def test_split_quantized_tensor(dtype, num_experts, m, k):
+
+    tensor = torch.randn((m * num_experts, k), dtype=dtype, device="cuda")
+    m_splits = [m] * num_experts
+
+    quantizer = MXFP8Quantizer(
+        fp8_dtype=tex.DType.kFloat8E4M3,
+        rowwise=True,
+        columnwise=True,
+    )
+
+    # Reference: split first, then quantize each chunk
+    ref_mxfp8 = tex.split_quantize(tensor, m_splits, [quantizer] * num_experts)
+
+    # Quantize the whole tensor, then split the quantized result
+    out_mxfp8 = tex.split_quantized_tensor(quantizer.quantize(tensor), m_splits)
+
+    for ref, out in zip(ref_mxfp8, out_mxfp8):
+        assert ref._quantizer.rowwise_usage == out._quantizer.rowwise_usage
+        assert ref._quantizer.columnwise_usage == out._quantizer.columnwise_usage
+        assert ref._quantizer.dtype == out._quantizer.dtype
+        assert ref._quantizer.internal == out._quantizer.internal
+
+        torch.testing.assert_close(ref._rowwise_data, out._rowwise_data, rtol=0, atol=0)
+        # Padded regions hold arbitrary values, so compare only the valid area.
+        torch.testing.assert_close(
+            ref._rowwise_scale_inv[:m, : k // 32],
+            out._rowwise_scale_inv[:m, : k // 32],
+            rtol=0,
+            atol=0,
+        )
+        torch.testing.assert_close(ref._columnwise_data, out._columnwise_data, rtol=0, atol=0)
+        # Padded regions hold arbitrary values, so compare only the valid area.
+        torch.testing.assert_close(
+            ref._columnwise_scale_inv[: m // 32, :k],
+            out._columnwise_scale_inv[: m // 32, :k],
+            rtol=0,
+            atol=0,
+        )
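
A note on the `// 32` slices in the assertions above: MXFP8 block scaling stores one shared scale per 32 contiguous elements, so for an `(m, k)` chunk only the first `m x (k // 32)` rowwise and `(m // 32) x k` columnwise scale-inverse entries are meaningful; anything beyond that is alignment padding with unspecified contents, hence the sliced comparison. A tiny illustration of the assumed layout:

```python
# Assumed MXFP8 layout: one shared scale per 32 contiguous elements.
m, k = 256, 128
rowwise_scales = (m, k // 32)      # (256, 4) entries compared per chunk
columnwise_scales = (m // 32, k)   # (8, 128) entries compared per chunk
print(rowwise_scales, columnwise_scales)
```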