Commit 80c88fc

add fsdp+tp tests for moe training
1 parent 6821971 commit 80c88fc

4 files changed: +291 -0 lines changed

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
#!/bin/bash

# terminate script on first error
set -e
IS_ROCM=$(rocm-smi --version || true)

# These tests do not work on ROCm yet
if [ -z "$IS_ROCM" ]
then
    ./test/prototype/moe_training/test_fsdp.sh
    ./test/prototype/moe_training/test_tp.sh
    ./test/prototype/moe_training/test_fsdp_tp.sh
fi

echo "all tests successful"

test/prototype/moe_training/test_fsdp.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py
#
#######################################################################

import copy
import os

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp_tp.py
#
#######################################################################

import copy
import os

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import Partial, Replicate, Shard
from torch.nn import functional as F

try:
    from torch.distributed.tensor.parallel import (
        PrepareModuleInputOutput,
        parallelize_module,
    )
except ImportError:
    import warnings

    warnings.warn(
        "torch version is too old, these tests require nightly build. Skipping MoE training tests."
    )
    pytest.skip(allow_module_level=True)

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

from .testing_utils import _validate_model_conversion

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.infra.expert_parallel import (
        ExpertParallel,
        ExpertTensorParallel,
        NoParallel,
        TensorParallel,
    )
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        # TODO: investigate hang when shared_expert is converted
        # ["experts,shared_expert"],
    ],
)
def test_moe_float8_training_fsdp_tp(target_fqns: list[str]):
    assert torch.cuda.is_available()

    # setup distributed for tp
    mesh = setup_distributed()

    # define model args
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
        vocab_size=1024,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(1)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # apply TP
    apply_moe_ep_tp(model, tp_mesh=mesh["tp"], ep_mesh=None, ep_tp_mesh=None)
    apply_moe_ep_tp(ref_model, tp_mesh=mesh["tp"], ep_mesh=None, ep_tp_mesh=None)

    # apply FSDP2
    fsdp_config = {"mesh": mesh["dp"]}
    fully_shard(model, **fsdp_config)
    fully_shard(ref_model, **fsdp_config)

    # Rough validation that parallelization was applied properly.
    assert isinstance(model.experts.w1.data, DTensor), (
        "test model experts.w1 is not a DTensor"
    )
    assert isinstance(model.experts.w2.data, DTensor), (
        "test model experts.w2 is not a DTensor"
    )
    assert isinstance(model.experts.w3.data, DTensor), (
        "test model experts.w3 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w1.data, DTensor), (
        "ref model experts.w1 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w2.data, DTensor), (
        "ref model experts.w2 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w3.data, DTensor), (
        "ref model experts.w3 is not a DTensor"
    )

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 28.0, (
        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )

    dist.destroy_process_group()


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
    device_mesh = init_device_mesh(
        "cuda",
        (world_size // 2, 2),
        mesh_dim_names=("dp", "tp"),
    )

    # seed must be the same in all processes
    torch.manual_seed(1)
    torch.cuda.set_device(rank)
    return device_mesh


def apply_moe_ep_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh | None,
    ep_mesh: DeviceMesh | None,
    ep_tp_mesh: DeviceMesh | None,
):
    # Modified version of moe parallelization from https://github.com/pytorch/torchtitan/pull/1324/
    # that supports single MoE layer independent of a transformer.
    if tp_mesh is not None:
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=model,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )

    # if ep_mesh is not None:
    experts_mesh, experts_plan = None, None
    if ep_mesh is None:
        experts_mesh = tp_mesh
        # input Replicate, output Partial
        experts_plan = TensorParallel()
    elif tp_mesh is None:
        experts_mesh = ep_mesh
        # input / output sharding on the batch / tokens dim
        experts_plan = ExpertParallel()
    else:
        experts_mesh = ep_tp_mesh
        experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)

    parallelize_module(
        module=model.experts,
        device_mesh=experts_mesh,
        parallelize_plan=experts_plan,
    )
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
torchrun --nproc_per_node=4 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp_tp.py -s
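
The new test composes tensor parallelism over the "tp" submesh with FSDP2 over the "dp" submesh of a single 2D device mesh. A minimal standalone sketch of that composition pattern, using a toy MLP in place of the MoE layer (the shapes, plan keys, and script name below are illustrative, not part of this commit):

import os

import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def main():
    # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK; NCCL backend, as in the test above
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = dist.get_world_size()

    # 2D mesh: outer "dp" dim for FSDP, inner "tp" dim of size 2 for tensor parallel
    mesh = init_device_mesh("cuda", (world_size // 2, 2), mesh_dim_names=("dp", "tp"))

    # toy MLP standing in for the MoE layer
    model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 256)).cuda()

    # TP first, on the "tp" submesh: shard the first linear column-wise and the
    # second row-wise so the intermediate activation stays sharded between them
    parallelize_module(
        model,
        device_mesh=mesh["tp"],
        parallelize_plan={"0": ColwiseParallel(), "2": RowwiseParallel()},
    )

    # then FSDP2 on the "dp" submesh; parameters become DTensors sharded
    # across both mesh dimensions
    fully_shard(model, mesh=mesh["dp"])

    # smoke-test forward + backward
    x = torch.randn(8, 256, device="cuda")
    model(x).sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Launched the same way as the tests above, e.g. torchrun --nproc_per_node=4 for a hypothetical sketch.py on a 4-GPU host.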

0 commit comments
