From cfef1fb06382fd104f024f83eb1e01ac2b913926 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 2 Jul 2025 14:50:40 -0700
Subject: [PATCH] add tests for ep

---
 .../prototype/moe_training/test_everything.sh |   1 +
 test/prototype/moe_training/test_fsdp_tp.py   |   2 +
 .../prototype/moe_training/test_fsdp_tp_ep.py | 263 ++++++++++++++++++
 .../prototype/moe_training/test_fsdp_tp_ep.sh |   1 +
 4 files changed, 267 insertions(+)
 create mode 100644 test/prototype/moe_training/test_fsdp_tp_ep.py
 create mode 100755 test/prototype/moe_training/test_fsdp_tp_ep.sh

diff --git a/test/prototype/moe_training/test_everything.sh b/test/prototype/moe_training/test_everything.sh
index 1a036cb7ea..b6004cfc13 100755
--- a/test/prototype/moe_training/test_everything.sh
+++ b/test/prototype/moe_training/test_everything.sh
@@ -15,6 +15,7 @@ then
 ./test/prototype/moe_training/test_fsdp.sh
 ./test/prototype/moe_training/test_tp.sh
 ./test/prototype/moe_training/test_fsdp_tp.sh
+./test/prototype/moe_training/test_fsdp_tp_ep.sh
 fi
 
 echo "all tests successful"
diff --git a/test/prototype/moe_training/test_fsdp_tp.py b/test/prototype/moe_training/test_fsdp_tp.py
index 3720a3525d..709ba69574 100644
--- a/test/prototype/moe_training/test_fsdp_tp.py
+++ b/test/prototype/moe_training/test_fsdp_tp.py
@@ -191,6 +191,8 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 def setup_distributed():
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
+    assert world_size >= 4, "world size must be >= 4 for 2D parallel"
+
     dist.init_process_group("nccl", rank=rank, world_size=world_size)
 
     # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
diff --git a/test/prototype/moe_training/test_fsdp_tp_ep.py b/test/prototype/moe_training/test_fsdp_tp_ep.py
new file mode 100644
index 0000000000..5d05310c30
--- /dev/null
+++ b/test/prototype/moe_training/test_fsdp_tp_ep.py
@@ -0,0 +1,263 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+######################################################################
+#
+# To run these unit tests, use the following command:
+#
+# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp_tp_ep.py
+#
+######################################################################
+
+import copy
+import os
+
+import pytest
+import torch
+from torch import distributed as dist
+from torch import nn
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._tensor import DTensor
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.tensor import Partial, Replicate, Shard
+from torch.nn import functional as F
+
+try:
+    from torch.distributed.tensor.parallel import (
+        PrepareModuleInputOutput,
+        parallelize_module,
+    )
+except ImportError:
+    import warnings
+
+    warnings.warn(
+        "torch version is too old; these tests require a nightly build. Skipping MoE training tests."
+    )
+    pytest.skip(allow_module_level=True)
+
+# this feature requires CUDA and SM89+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
+    pytest.skip(
+        "CUDA not available or compute capability < 8.9", allow_module_level=True
+    )
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.quantization.quant_api import quantize_
+
+from .testing_utils import _validate_model_conversion
+
+# this test requires torchtitan
+try:
+    from torchtitan.experiments.llama4.infra.expert_parallel import (
+        ExpertParallel,
+        ExpertTensorParallel,
+        NoParallel,
+        TensorParallel,
+    )
+    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
+    from torchtitan.experiments.llama4.model.moe import MoE
+except ImportError:
+    import warnings
+
+    warnings.warn("torchtitan not installed, skipping MoE tests.")
+    pytest.skip(allow_module_level=True)
+
+
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        # TODO: investigate hang when shared_expert is converted
+        # ["experts,shared_expert"],
+    ],
+)
+def test_moe_float8_training_fsdp_tp_ep(target_fqns: list[str]):
+    assert torch.cuda.is_available()
+
+    # setup distributed 3D device mesh (dp, ep, tp)
+    mesh = setup_distributed()
+
+    # define model args
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+        vocab_size=1024,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(1)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = copy.deepcopy(ref_model)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to float8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig()
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # apply TP and EP
+    apply_moe_ep_tp(
+        model, tp_mesh=mesh["tp"], ep_mesh=mesh["ep"], ep_tp_mesh=mesh["ep", "tp"]
+    )
+    apply_moe_ep_tp(
+        ref_model, tp_mesh=mesh["tp"], ep_mesh=mesh["ep"], ep_tp_mesh=mesh["ep", "tp"]
+    )
+
+    # apply FSDP2
+    fsdp_config = {"mesh": mesh["dp"]}
+    fully_shard(model, **fsdp_config)
+    fully_shard(ref_model, **fsdp_config)
+
+    # Rough validation that parallelization was applied properly.
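+    # After apply_moe_ep_tp and fully_shard, expert weights should be DTensors
+    # (sharded by EP/TP and wrapped by FSDP) rather than plain tensors.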
+    assert isinstance(model.experts.w1.data, DTensor), (
+        "test model experts.w1 is not a DTensor"
+    )
+    assert isinstance(model.experts.w2.data, DTensor), (
+        "test model experts.w2 is not a DTensor"
+    )
+    assert isinstance(model.experts.w3.data, DTensor), (
+        "test model experts.w3 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w1.data, DTensor), (
+        "ref model experts.w1 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w2.data, DTensor), (
+        "ref model experts.w2 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w3.data, DTensor), (
+        "ref model experts.w3 is not a DTensor"
+    )
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    assert input_grad_sqnr.item() >= 28.0, (
+        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= 25.0, (
+            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        )
+
+    dist.destroy_process_group()
+
+
+def setup_distributed():
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    assert world_size == 8, "world size must be == 8 for 3D parallel test"
+
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
+    device_mesh = init_device_mesh(
+        "cuda",
+        (2, 2, 2),
+        mesh_dim_names=("dp", "ep", "tp"),
+    )
+
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    torch.cuda.set_device(rank)
+    return device_mesh
+
+
+def apply_moe_ep_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh | None,
+    ep_mesh: DeviceMesh | None,
+    ep_tp_mesh: DeviceMesh | None,
+):
+    # Modified version of the MoE parallelization from https://github.com/pytorch/torchtitan/pull/1324/
+    # that supports a single MoE layer independent of a full transformer.
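+    # Expert placement below: TensorParallel when only a TP mesh is given,
+    # ExpertParallel when only an EP mesh is given, and ExpertTensorParallel
+    # over the combined (ep, tp) mesh when both are present.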
+    if tp_mesh is not None:
+        moe_layer_plan = {
+            # input / output sharding on the seqlen dim
+            # all-gather for input, reduce-scatter for output
+            "moe": PrepareModuleInputOutput(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+                use_local_input=True,
+                output_layouts=(Partial(),),
+                desired_output_layouts=(Shard(1),),
+            ),
+            # replicate computation for the router
+            "moe.router.gate": NoParallel(),
+            # input Replicate, output Partial
+            "moe.shared_expert": TensorParallel(),
+        }
+        parallelize_module(
+            module=model,
+            device_mesh=tp_mesh,
+            parallelize_plan=moe_layer_plan,
+        )
+
+    experts_mesh, experts_plan = None, None
+    if ep_mesh is None:
+        experts_mesh = tp_mesh
+        # input Replicate, output Partial
+        experts_plan = TensorParallel()
+    elif tp_mesh is None:
+        experts_mesh = ep_mesh
+        # input / output sharding on the batch / tokens dim
+        experts_plan = ExpertParallel()
+    else:
+        experts_mesh = ep_tp_mesh
+        experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+
+    parallelize_module(
+        module=model.experts,
+        device_mesh=experts_mesh,
+        parallelize_plan=experts_plan,
+    )
diff --git a/test/prototype/moe_training/test_fsdp_tp_ep.sh b/test/prototype/moe_training/test_fsdp_tp_ep.sh
new file mode 100755
index 0000000000..dba02b2c0a
--- /dev/null
+++ b/test/prototype/moe_training/test_fsdp_tp_ep.sh
@@ -0,0 +1 @@
+torchrun --nproc_per_node=8 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp_tp_ep.py -s