@@ -4,7 +4,6 @@
 
 import math
 import os
-from pickletools import genops
 from typing import Dict, List, Tuple, Optional
 import pytest
 import random
@@ -108,8 +107,6 @@
     "swiglu": te.ops.SwiGLU,
 }
 
-gated_act_ops = (te.ops.GEGLU, te.ops.QGEGLU, te.ops.ReGLU, te.ops.SReGLU, te.ops.SwiGLU)
-
 all_normalizations = ["LayerNorm", "RMSNorm"]
 
 mask_types = ["causal", "no_mask"]
@@ -1729,20 +1726,34 @@ def _test_grouped_linear_accuracy(
     fp8,
     fuse_wgrad_accumulation,
     delay_wgrad_compute=False,
-    activation_func=te.ops.Identity(),
+    activation_func=None,
+    fp8_activation=False,
 ):
     reset_rng_states()
     if fp8:
         FP8GlobalStateManager.reset()
 
-    if isinstance(activation_func, gated_act_ops) or (
-        isinstance(activation_func, te.ops.Sequential)
-        and isinstance(activation_func[0], gated_act_ops)
-    ):
+    if activation_func in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
         hidden_size = 2 * config.hidden_size
     else:
         hidden_size = config.hidden_size
 
+    if activation_func is None:
+        input_act_func = te.ops.Identity()
+        output_act_func = te.ops.Identity()
+    elif fp8_activation:
+        input_act_func = te.ops.Sequential(
+            act_ops[activation_func](),
+            te.ops.Quantize(forward=True, backward=False),  # Output QuantizedTensor in forward
+        )
+        output_act_func = te.ops.Sequential(
+            te.ops.Quantize(forward=False, backward=True),  # Output QuantizedTensor in backward
+            act_ops[activation_func](),
+        )
+    else:
+        input_act_func = act_ops[activation_func]()
+        output_act_func = act_ops[activation_func]()
+
     inp_hidden_states = torch.randn(
         (config.max_seqlen_q, bs, hidden_size),
         dtype=dtype,
@@ -1769,11 +1780,11 @@ def _test_grouped_linear_accuracy(
     with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
         if isinstance(block, GroupedLinear):
             m_splits = m_splits * bs
-            out = block(activation_func(inp_hidden_states), m_splits.tolist())
+            out = output_act_func(block(input_act_func(inp_hidden_states), m_splits.tolist()))
         else:
             out = torch.cat(
                 [
-                    block[i](activation_func(inp))
+                    output_act_func(block[i](input_act_func(inp)))
                     for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                 ]
             )
@@ -2063,12 +2074,6 @@ def test_grouped_linear_fp8_input(
         weight_i_copy = getattr(grouped_linear_fp8_input, f"weight{i}")
         weight_i_copy.main_grad = weight_i.main_grad.clone()
 
-    bf16_activation = act_ops[activation]()
-    fp8_activation = te.ops.Sequential(
-        bf16_activation,
-        te.ops.Quantize(forward=True, backward=False),  # Output QuantizedTensor in forward
-    )
-
     outputs_ref = _test_grouped_linear_accuracy(
         grouped_linear_bf16_input,
         num_gemms,
@@ -2078,7 +2083,8 @@ def test_grouped_linear_fp8_input(
         recipe,
         fp8=True,
         fuse_wgrad_accumulation=True,
-        activation_func=bf16_activation,
+        activation_func=activation,
+        fp8_activation=False,
     )
     outputs = _test_grouped_linear_accuracy(
         grouped_linear_fp8_input,
@@ -2089,7 +2095,8 @@ def test_grouped_linear_fp8_input(
         recipe,
         fp8=True,
         fuse_wgrad_accumulation=True,
-        activation_func=fp8_activation,
+        activation_func=activation,
+        fp8_activation=True,
     )
     # Should be a bit-wise match
     for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
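For context on the new `fp8_activation` path, below is a minimal standalone sketch (not part of the diff) of the pattern the helper now builds internally: an activation wrapped with `te.ops.Quantize` so that quantization happens only in the forward pass (for the input-side activation) or only in the backward pass (for the output-side activation). The `swiglu` choice, tensor shape, and the assumption of an FP8-capable GPU with `transformer_engine` installed are illustrative, not taken from the diff.

```python
# Minimal sketch; assumes transformer_engine is installed and an FP8-capable GPU is available.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_autocast

# Activation whose forward output is an FP8 QuantizedTensor
# (mirrors the `input_act_func` branch when fp8_activation=True).
input_act_func = te.ops.Sequential(
    te.ops.SwiGLU(),
    te.ops.Quantize(forward=True, backward=False),  # quantize activations in forward only
)

# Activation whose incoming gradient is quantized in backward
# (mirrors the `output_act_func` branch when fp8_activation=True).
output_act_func = te.ops.Sequential(
    te.ops.Quantize(forward=False, backward=True),  # quantize gradients in backward only
    te.ops.SwiGLU(),
)

# SwiGLU is gated, so a 256-wide input produces a 128-wide output.
x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with fp8_autocast(enabled=True):
    y = input_act_func(x)  # FP8 QuantizedTensor, ready to feed a (Grouped)Linear
```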