Commit 8537883

[mxfp8 moe training] Add mxfp8 to FSDP tests (#2849)
1 parent f03a737 commit 8537883

File tree

1 file changed: +22 -12 lines changed


test/prototype/moe_training/test_fsdp.py

Lines changed: 22 additions & 12 deletions
@@ -27,32 +27,44 @@
     "CUDA not available or compute capability < 8.9", allow_module_level=True
 )
 
+from testing_utils import _validate_model_conversion
+
 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
-from .testing_utils import _validate_model_conversion
-
 # this test requires torchtitan
 try:
-    from torchtitan.distributed.expert_parallel import (
-        set_token_group_alignment_size_m,
-    )
+    from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
     from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
     )
 
 
-def test_moe_float8_training_fsdp():
+@pytest.mark.parametrize(
+    "recipe, min_out_sqnr, alignment_size, min_param_grad_sqnr",
+    [
+        (MoEScalingType.FP8_ROWWISE, 29.0, 16, 23.0),
+        (MoEScalingType.MXFP8, 28.0, 32, 21.0),
+    ],
+)
+def test_moe_float8_training_fsdp(
+    recipe: MoEScalingType,
+    min_out_sqnr: float,
+    alignment_size: int,
+    min_param_grad_sqnr: float,
+):
     assert torch.cuda.is_available()
 
     # setup distributed for fsdp
     setup_distributed()
 
-    # token group aligment size must be 16 for fp8
-    set_token_group_alignment_size_m(16)
+    set_token_group_alignment_size_m(alignment_size)
 
     # define model args
     target_fqns = ["experts"]
@@ -83,7 +95,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
@@ -109,7 +121,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    min_out_sqnr = 29.0
     assert out_sqnr.item() >= min_out_sqnr, (
         f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
     )
@@ -131,7 +142,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # validate param gradients
-    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
         assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
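
For context, a minimal standalone sketch of the per-recipe SQNR check that this diff parametrizes. It assumes only the torchao helpers already imported in the diff (compute_error, MoEScalingType, MoETrainingConfig); the random tensors are toy stand-ins for the MoE forward outputs that the real test compares, and the config it builds is what the test hands to quantize_ together with a filter targeting the "experts" modules.

# Minimal sketch (not part of the commit) of the per-recipe SQNR check this
# diff parametrizes. It uses only APIs visible in the diff: compute_error,
# MoEScalingType, and MoETrainingConfig. The random tensors below are toy
# stand-ins for the MoE forward outputs the real test compares under FSDP.
import torch

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)

# (recipe, min_out_sqnr, alignment_size, min_param_grad_sqnr), mirroring the
# new @pytest.mark.parametrize cases.
CASES = [
    (MoEScalingType.FP8_ROWWISE, 29.0, 16, 23.0),
    (MoEScalingType.MXFP8, 28.0, 32, 21.0),
]

ref_out = torch.randn(64, 128)
out = ref_out + 1e-3 * torch.randn_like(ref_out)  # pretend quantized output

for recipe, min_out_sqnr, alignment_size, min_param_grad_sqnr in CASES:
    # In the real test, this config is passed to quantize_() with a filter_fn
    # that only matches modules under the "experts" FQN, and alignment_size /
    # min_param_grad_sqnr drive set_token_group_alignment_size_m and the
    # gradient checks respectively.
    config = MoETrainingConfig(recipe)

    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= min_out_sqnr, (
        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
    )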
