2727 "CUDA not available or compute capability < 8.9" , allow_module_level = True
2828 )
2929
30+ from testing_utils import _validate_model_conversion
31+
3032from torchao .float8 .float8_utils import compute_error
31- from torchao .prototype .moe_training .conversion_utils import MoETrainingConfig
33+ from torchao .prototype .moe_training .conversion_utils import (
34+ MoEScalingType ,
35+ MoETrainingConfig ,
36+ )
3237from torchao .quantization .quant_api import quantize_
3338
34- from .testing_utils import _validate_model_conversion
35-
3639# this test requires torchtitan
3740try :
38- from torchtitan .distributed .expert_parallel import (
39- set_token_group_alignment_size_m ,
40- )
41+ from torchtitan .distributed .expert_parallel import set_token_group_alignment_size_m
4142 from torchtitan .models .moe import MoE , MoEArgs
4243except ImportError :
4344 pytest .skip (
4445 "torchtitan not installed, skipping MoE tests." , allow_module_level = True
4546 )
4647
4748
48- def test_moe_float8_training_fsdp ():
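+# each case: (scaling recipe, min output SQNR, token group alignment size, min param grad SQNR)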
+@pytest.mark.parametrize(
+    "recipe, min_out_sqnr, alignment_size, min_param_grad_sqnr",
+    [
+        (MoEScalingType.FP8_ROWWISE, 29.0, 16, 23.0),
+        (MoEScalingType.MXFP8, 28.0, 32, 21.0),
+    ],
+)
+def test_moe_float8_training_fsdp(
+    recipe: MoEScalingType,
+    min_out_sqnr: float,
+    alignment_size: int,
+    min_param_grad_sqnr: float,
+):
     assert torch.cuda.is_available()
 
     # setup distributed for fsdp
     setup_distributed()
 
-    # token group aligment size must be 16 for fp8
-    set_token_group_alignment_size_m(16)
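+    # token group alignment size depends on the recipe: 16 for fp8 rowwise, 32 for mxfp8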
+    set_token_group_alignment_size_m(alignment_size)
 
     # define model args
     target_fqns = ["experts"]
@@ -83,7 +95,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
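+    # the config now carries the parametrized scaling recipe used to convert the experts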
+    config = MoETrainingConfig(recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
@@ -109,7 +121,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    min_out_sqnr = 29.0
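+    # min_out_sqnr is supplied per-recipe by the parametrization above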
     assert out_sqnr.item() >= min_out_sqnr, (
         f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
     )
@@ -131,7 +142,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # validate param gradients
-    min_param_grad_sqnr = 23.0
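+    # min_param_grad_sqnr is likewise supplied by the parametrization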
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
         assert param_grad_sqnr.item() >= min_param_grad_sqnr, (