Skip to content

Commit 48746dc

Browse files
committed
add more tests
Signed-off-by: Xin Yao <[email protected]>
1 parent f212fc6 commit 48746dc

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

tests/pytorch/test_numerics.py

Lines changed: 29 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import math
66
import os
7+
from pickletools import genops
78
from typing import Dict, List, Tuple, Optional
89
import pytest
910
import random
@@ -94,6 +95,21 @@
9495
"swiglu",
9596
]
9697

98+
act_ops = {
99+
"gelu": te.ops.GELU,
100+
"geglu": te.ops.GEGLU,
101+
"qgelu": te.ops.QGELU,
102+
"qgeglu": te.ops.QGEGLU,
103+
"relu": te.ops.ReLU,
104+
"reglu": te.ops.ReGLU,
105+
"srelu": te.ops.SReLU,
106+
"sreglu": te.ops.SReGLU,
107+
"silu": te.ops.SiLU,
108+
"swiglu": te.ops.SwiGLU,
109+
}
110+
111+
gated_act_ops = (te.ops.GEGLU, te.ops.QGEGLU, te.ops.ReGLU, te.ops.SReGLU, te.ops.SwiGLU)
112+
97113
all_normalizations = ["LayerNorm", "RMSNorm"]
98114

99115
mask_types = ["causal", "no_mask"]
@@ -1713,14 +1729,20 @@ def _test_grouped_linear_accuracy(
17131729
fp8,
17141730
fuse_wgrad_accumulation,
17151731
delay_wgrad_compute=False,
1716-
activation_func=None, # assume gated activation function
1732+
activation_func=te.ops.Identity(),
17171733
):
17181734
reset_rng_states()
17191735
if fp8:
17201736
FP8GlobalStateManager.reset()
17211737

1722-
# assume gated activation function
1723-
hidden_size = config.hidden_size if activation_func is None else 2 * config.hidden_size
1738+
if isinstance(activation_func, gated_act_ops) or (
1739+
isinstance(activation_func, te.ops.Sequential)
1740+
and isinstance(activation_func[0], gated_act_ops)
1741+
):
1742+
hidden_size = 2 * config.hidden_size
1743+
else:
1744+
hidden_size = config.hidden_size
1745+
17241746
inp_hidden_states = torch.randn(
17251747
(config.max_seqlen_q, bs, hidden_size),
17261748
dtype=dtype,
@@ -1993,13 +2015,15 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
19932015
@pytest.mark.parametrize("model", ["126m"])
19942016
@pytest.mark.parametrize("recipe", [recipe.MXFP8BlockScaling()])
19952017
@pytest.mark.parametrize("fp8_model_params", all_boolean)
2018+
@pytest.mark.parametrize("activation", all_activations)
19962019
def test_grouped_linear_fp8_input(
19972020
dtype,
19982021
num_gemms,
19992022
bs,
20002023
model,
20012024
recipe,
20022025
fp8_model_params,
2026+
activation,
20032027
):
20042028
config = model_configs[model]
20052029
if config.max_seqlen_q % 32 != 0:
@@ -2039,9 +2063,9 @@ def test_grouped_linear_fp8_input(
20392063
weight_i_copy = getattr(grouped_linear_fp8_input, f"weight{i}")
20402064
weight_i_copy.main_grad = weight_i.main_grad.clone()
20412065

2042-
bf16_activation = te.ops.SwiGLU()
2066+
bf16_activation = act_ops[activation]()
20432067
fp8_activation = te.ops.Sequential(
2044-
te.ops.SwiGLU(),
2068+
bf16_activation,
20452069
te.ops.Quantize(forward=True, backward=False), # Output QuantizedTensor in forward
20462070
)
20472071

0 commit comments

Comments (0)