|
4 | 4 |
|
5 | 5 | import math
|
6 | 6 | import os
|
7 | 8 | from typing import Dict, List, Tuple, Optional
|
8 | 9 | import pytest
|
9 | 10 | import random
|
|
94 | 95 | "swiglu",
|
95 | 96 | ]
|
96 | 97 |
|
| 98 | +act_ops = { |
| 99 | + "gelu": te.ops.GELU, |
| 100 | + "geglu": te.ops.GEGLU, |
| 101 | + "qgelu": te.ops.QGELU, |
| 102 | + "qgeglu": te.ops.QGEGLU, |
| 103 | + "relu": te.ops.ReLU, |
| 104 | + "reglu": te.ops.ReGLU, |
| 105 | + "srelu": te.ops.SReLU, |
| 106 | + "sreglu": te.ops.SReGLU, |
| 107 | + "silu": te.ops.SiLU, |
| 108 | + "swiglu": te.ops.SwiGLU, |
| 109 | +} |
| 110 | + |
| 111 | +gated_act_ops = (te.ops.GEGLU, te.ops.QGEGLU, te.ops.ReGLU, te.ops.SReGLU, te.ops.SwiGLU) |
| 112 | + |
97 | 113 | all_normalizations = ["LayerNorm", "RMSNorm"]
|
98 | 114 |
|
99 | 115 | mask_types = ["causal", "no_mask"]
|
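The new act_ops table pairs each name in all_activations with its te.ops class, and gated_act_ops lists the GLU-style variants whose output width is half their input width. A minimal usage sketch, assuming the file's existing te alias for transformer_engine.pytorch (illustrative only, not part of the diff):

    act = act_ops["swiglu"]()                  # -> te.ops.SwiGLU instance
    is_gated = isinstance(act, gated_act_ops)  # True: gated ops map 2*H input features to H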
@@ -1713,14 +1729,20 @@ def _test_grouped_linear_accuracy(
|
1713 | 1729 | fp8,
|
1714 | 1730 | fuse_wgrad_accumulation,
|
1715 | 1731 | delay_wgrad_compute=False,
|
1716 |
| - activation_func=None, # assume gated activation function |
| 1732 | + activation_func=te.ops.Identity(), |
1717 | 1733 | ):
|
1718 | 1734 | reset_rng_states()
|
1719 | 1735 | if fp8:
|
1720 | 1736 | FP8GlobalStateManager.reset()
|
1721 | 1737 |
|
1722 |
| - # assume gated activation function |
1723 |
| - hidden_size = config.hidden_size if activation_func is None else 2 * config.hidden_size |
| 1738 | + if isinstance(activation_func, gated_act_ops) or ( |
| 1739 | + isinstance(activation_func, te.ops.Sequential) |
| 1740 | + and isinstance(activation_func[0], gated_act_ops) |
| 1741 | + ): |
| 1742 | + hidden_size = 2 * config.hidden_size |
| 1743 | + else: |
| 1744 | + hidden_size = config.hidden_size |
| 1745 | + |
1724 | 1746 | inp_hidden_states = torch.randn(
|
1725 | 1747 | (config.max_seqlen_q, bs, hidden_size),
|
1726 | 1748 | dtype=dtype,
|
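The hidden-size branch above doubles the width because gated activations split the last dimension into value and gate halves before combining them, so they consume 2*H features and emit H. A rough shape check, assuming a CUDA device and the same torch/te imports as the rest of the file (a sketch, not part of the change):

    x = torch.randn(8, 2 * 128, dtype=torch.bfloat16, device="cuda")
    assert te.ops.SwiGLU()(x).shape[-1] == 128   # gated: 2*H -> H

    y = torch.randn(8, 128, dtype=torch.bfloat16, device="cuda")
    assert te.ops.GELU()(y).shape[-1] == 128     # non-gated: width unchanged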
@@ -1993,13 +2015,15 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
|
1993 | 2015 | @pytest.mark.parametrize("model", ["126m"])
|
1994 | 2016 | @pytest.mark.parametrize("recipe", [recipe.MXFP8BlockScaling()])
|
1995 | 2017 | @pytest.mark.parametrize("fp8_model_params", all_boolean)
|
| 2018 | +@pytest.mark.parametrize("activation", all_activations) |
1996 | 2019 | def test_grouped_linear_fp8_input(
|
1997 | 2020 | dtype,
|
1998 | 2021 | num_gemms,
|
1999 | 2022 | bs,
|
2000 | 2023 | model,
|
2001 | 2024 | recipe,
|
2002 | 2025 | fp8_model_params,
|
| 2026 | + activation, |
2003 | 2027 | ):
|
2004 | 2028 | config = model_configs[model]
|
2005 | 2029 | if config.max_seqlen_q % 32 != 0:
|
@@ -2039,9 +2063,9 @@ def test_grouped_linear_fp8_input(
|
2039 | 2063 | weight_i_copy = getattr(grouped_linear_fp8_input, f"weight{i}")
|
2040 | 2064 | weight_i_copy.main_grad = weight_i.main_grad.clone()
|
2041 | 2065 |
|
2042 |
| - bf16_activation = te.ops.SwiGLU() |
| 2066 | + bf16_activation = act_ops[activation]() |
2043 | 2067 | fp8_activation = te.ops.Sequential(
|
2044 |
| - te.ops.SwiGLU(), |
| 2068 | + bf16_activation, |
2045 | 2069 | te.ops.Quantize(forward=True, backward=False), # Output QuantizedTensor in forward
|
2046 | 2070 | )
|
2047 | 2071 |
|
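Each value of the new activation parameter follows the construction shown above: the op class looked up in act_ops serves as the bf16 baseline, and the fp8 path wraps that same instance in te.ops.Sequential with a forward-only Quantize, so only the forward output becomes a QuantizedTensor. A condensed sketch with an illustrative name, not taken from the test body:

    name = "gelu"                                         # any key of act_ops
    ref_act = act_ops[name]()                             # bf16 reference activation
    fp8_act = te.ops.Sequential(
        ref_act,
        te.ops.Quantize(forward=True, backward=False),    # quantize the forward output only
    )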