|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +###################################################################### |
| 7 | +# |
| 8 | +# To run these unit tests, use the following command: |
| 9 | +# |
| 10 | +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp_tp.py |
| 11 | +# |
| 12 | +####################################################################### |
| 13 | + |
| 14 | +import copy |
| 15 | +import os |
| 16 | + |
| 17 | +import pytest |
| 18 | +import torch |
| 19 | +from torch import distributed as dist |
| 20 | +from torch import nn |
| 21 | +from torch.distributed._composable.fsdp import fully_shard |
| 22 | +from torch.distributed._tensor import DTensor |
| 23 | +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh |
| 24 | +from torch.distributed.tensor import Partial, Replicate, Shard |
| 25 | +from torch.nn import functional as F |
| 26 | + |
| 27 | +try: |
| 28 | + from torch.distributed.tensor.parallel import ( |
| 29 | + PrepareModuleInputOutput, |
| 30 | + parallelize_module, |
| 31 | + ) |
| 32 | +except ImportError: |
| 33 | + import warnings |
| 34 | + |
| 35 | + warnings.warn( |
| 36 | + "torch version is too old, these tests require nightly build. Skipping MoE training tests." |
| 37 | + ) |
| 38 | + pytest.skip(allow_module_level=True) |
| 39 | + |
| 40 | +# this feature requires CUDA and SM89+ |
| 41 | +if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): |
| 42 | + pytest.skip( |
| 43 | + "CUDA not available or compute capability < 8.9", allow_module_level=True |
| 44 | + ) |
| 45 | + |
| 46 | +from torchao.float8.float8_utils import compute_error |
| 47 | +from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig |
| 48 | +from torchao.quantization.quant_api import quantize_ |
| 49 | + |
| 50 | +from .testing_utils import _validate_model_conversion |
| 51 | + |
| 52 | +# this test requires torchtitan |
| 53 | +try: |
| 54 | + from torchtitan.experiments.llama4.infra.expert_parallel import ( |
| 55 | + ExpertParallel, |
| 56 | + ExpertTensorParallel, |
| 57 | + NoParallel, |
| 58 | + TensorParallel, |
| 59 | + ) |
| 60 | + from torchtitan.experiments.llama4.model.args import TransformerModelArgs |
| 61 | + from torchtitan.experiments.llama4.model.moe import MoE |
| 62 | +except ImportError: |
| 63 | + import warnings |
| 64 | + |
| 65 | + warnings.warn("torchtitan not installed, skipping MoE tests.") |
| 66 | + pytest.skip(allow_module_level=True) |
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.parametrize( |
| 70 | + "target_fqns", |
| 71 | + [ |
| 72 | + ["experts"], |
| 73 | + # TODO: investigate hang when shared_expert is converted |
| 74 | + # ["experts,shared_expert"], |
| 75 | + ], |
| 76 | +) |
| 77 | +def test_moe_float8_training_fsdp_tp(target_fqns: list[str]): |
| 78 | + assert torch.cuda.is_available() |
| 79 | + |
| 80 | + # setup distributed for tp |
| 81 | + mesh = setup_distributed() |
| 82 | + |
| 83 | + # define model args |
| 84 | + model_args = TransformerModelArgs( |
| 85 | + moe_enabled=True, |
| 86 | + num_experts=8, |
| 87 | + dim=256, |
| 88 | + vocab_size=1024, |
| 89 | + ) |
| 90 | + init_std = 0.02 |
| 91 | + device = torch.device("cuda") |
| 92 | + |
| 93 | + # reference bf16 MoE |
| 94 | + ref_model = MoE(model_args).to(torch.bfloat16).cuda() |
| 95 | + torch.manual_seed(1) |
| 96 | + ref_model.init_weights(init_std, device) |
| 97 | + |
| 98 | + # target MoE for testing conversion |
| 99 | + model = copy.deepcopy(ref_model) |
| 100 | + |
| 101 | + # assert starting params are identical for both models |
| 102 | + for param1, param2 in zip(model.parameters(), ref_model.parameters()): |
| 103 | + assert torch.equal(param1, param2) |
| 104 | + |
| 105 | + # convert MoE to float8 training |
| 106 | + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: |
| 107 | + for target_fqn in target_fqns: |
| 108 | + if target_fqn in cur_fqn: |
| 109 | + return True |
| 110 | + return False |
| 111 | + |
| 112 | + # quantize test model |
| 113 | + config = MoETrainingConfig() |
| 114 | + quantize_(model, config=config, filter_fn=moe_module_filter_fn) |
| 115 | + |
| 116 | + # validate that only the experts were converted |
| 117 | + _validate_model_conversion( |
| 118 | + model, |
| 119 | + target_fqns=target_fqns, |
| 120 | + ) |
| 121 | + |
| 122 | + # apply TP |
| 123 | + apply_moe_ep_tp(model, tp_mesh=mesh["tp"], ep_mesh=None, ep_tp_mesh=None) |
| 124 | + apply_moe_ep_tp(ref_model, tp_mesh=mesh["tp"], ep_mesh=None, ep_tp_mesh=None) |
| 125 | + |
| 126 | + # apply FSDP2 |
| 127 | + fsdp_config = {"mesh": mesh["dp"]} |
| 128 | + fully_shard(model, **fsdp_config) |
| 129 | + fully_shard(ref_model, **fsdp_config) |
| 130 | + |
| 131 | + # Rough validation that parallelization was applied properly. |
| 132 | + assert isinstance(model.experts.w1.data, DTensor), ( |
| 133 | + "test model experts.w1 is not a DTensor" |
| 134 | + ) |
| 135 | + assert isinstance(model.experts.w2.data, DTensor), ( |
| 136 | + "test model experts.w2 is not a DTensor" |
| 137 | + ) |
| 138 | + assert isinstance(model.experts.w3.data, DTensor), ( |
| 139 | + "test model experts.w3 is not a DTensor" |
| 140 | + ) |
| 141 | + assert isinstance(ref_model.experts.w1.data, DTensor), ( |
| 142 | + "ref model experts.w1 is not a DTensor" |
| 143 | + ) |
| 144 | + assert isinstance(ref_model.experts.w2.data, DTensor), ( |
| 145 | + "ref model experts.w2 is not a DTensor" |
| 146 | + ) |
| 147 | + assert isinstance(ref_model.experts.w3.data, DTensor), ( |
| 148 | + "ref model experts.w3 is not a DTensor" |
| 149 | + ) |
| 150 | + |
| 151 | + # inputs |
| 152 | + batch, seq, dim = 8, 2048, 256 |
| 153 | + ref_x = torch.randn( |
| 154 | + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device |
| 155 | + ) |
| 156 | + x = ref_x.detach().clone().requires_grad_(True) |
| 157 | + |
| 158 | + # forward pass |
| 159 | + ref_out = ref_model(ref_x) |
| 160 | + out = model(x) |
| 161 | + |
| 162 | + # validate output |
| 163 | + out_sqnr = compute_error(out, ref_out) |
| 164 | + assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." |
| 165 | + |
| 166 | + # compute loss |
| 167 | + labels = torch.ones_like(ref_out) |
| 168 | + ref_loss = F.mse_loss(ref_out, labels) |
| 169 | + out_loss = F.mse_loss(out, labels) |
| 170 | + |
| 171 | + # backward pass |
| 172 | + ref_loss.backward() |
| 173 | + out_loss.backward() |
| 174 | + |
| 175 | + # validate input gradient |
| 176 | + input_grad_sqnr = compute_error(x.grad, ref_x.grad) |
| 177 | + assert input_grad_sqnr.item() >= 28.0, ( |
| 178 | + f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}." |
| 179 | + ) |
| 180 | + |
| 181 | + # validate param gradients |
| 182 | + for param1, param2 in zip(model.parameters(), ref_model.parameters()): |
| 183 | + param_grad_sqnr = compute_error(param1.grad, param2.grad) |
| 184 | + assert param_grad_sqnr.item() >= 25.0, ( |
| 185 | + f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." |
| 186 | + ) |
| 187 | + |
| 188 | + dist.destroy_process_group() |
| 189 | + |
| 190 | + |
| 191 | +def setup_distributed(): |
| 192 | + rank = int(os.environ["RANK"]) |
| 193 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 194 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 195 | + |
| 196 | + # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html |
| 197 | + device_mesh = init_device_mesh( |
| 198 | + "cuda", |
| 199 | + (world_size // 2, 2), |
| 200 | + mesh_dim_names=("dp", "tp"), |
| 201 | + ) |
| 202 | + |
| 203 | + # seed must be the same in all processes |
| 204 | + torch.manual_seed(1) |
| 205 | + torch.cuda.set_device(rank) |
| 206 | + return device_mesh |
| 207 | + |
| 208 | + |
| 209 | +def apply_moe_ep_tp( |
| 210 | + model: nn.Module, |
| 211 | + tp_mesh: DeviceMesh | None, |
| 212 | + ep_mesh: DeviceMesh | None, |
| 213 | + ep_tp_mesh: DeviceMesh | None, |
| 214 | +): |
| 215 | + # Modified version of moe parallelization from https://github.com/pytorch/torchtitan/pull/1324/ |
| 216 | + # that supports single MoE layer independent of a transformer. |
| 217 | + if tp_mesh is not None: |
| 218 | + moe_layer_plan = { |
| 219 | + # input / output sharding on the seqlen dim |
| 220 | + # all-gather for input, reduce-scatter for output |
| 221 | + "moe": PrepareModuleInputOutput( |
| 222 | + input_layouts=(Shard(1),), |
| 223 | + desired_input_layouts=(Replicate(),), |
| 224 | + use_local_input=True, |
| 225 | + output_layouts=(Partial(),), |
| 226 | + desired_output_layouts=(Shard(1),), |
| 227 | + ), |
| 228 | + # replicate computation for the router |
| 229 | + "moe.router.gate": NoParallel(), |
| 230 | + # input Replicate, output Partial |
| 231 | + "moe.shared_expert": TensorParallel(), |
| 232 | + } |
| 233 | + parallelize_module( |
| 234 | + module=model, |
| 235 | + device_mesh=tp_mesh, |
| 236 | + parallelize_plan=moe_layer_plan, |
| 237 | + ) |
| 238 | + |
| 239 | + # if ep_mesh is not None: |
| 240 | + experts_mesh, experts_plan = None, None |
| 241 | + if ep_mesh is None: |
| 242 | + experts_mesh = tp_mesh |
| 243 | + # input Replicate, output Partial |
| 244 | + experts_plan = TensorParallel() |
| 245 | + elif tp_mesh is None: |
| 246 | + experts_mesh = ep_mesh |
| 247 | + # input / output sharding on the batch / tokens dim |
| 248 | + experts_plan = ExpertParallel() |
| 249 | + else: |
| 250 | + experts_mesh = ep_tp_mesh |
| 251 | + experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) |
| 252 | + |
| 253 | + parallelize_module( |
| 254 | + module=model.experts, |
| 255 | + device_mesh=experts_mesh, |
| 256 | + parallelize_plan=experts_plan, |
| 257 | + ) |
0 commit comments