Enable MoE (Qwen3-30B-A3B) with Expert Parallelism support #87
shivam-MBZUAI wants to merge 1 commit into
- Switch model to Qwen3-30B-A3B (30.5B total, 3.3B active params)
- Add Expert Parallelism (EP) detection and validation in env.py
- Add train_ep.py reference implementation with EPMoeBlock, gradient sync, and full state dict gathering via broadcast
- Fix FSDP reference run: transformer_auto_wrap_policy, disable no_sync to prevent full-gradient OOM, disable gradient checkpointing
- Use SGD optimizer (lr=0.1) across all strategies for bf16 compatibility
- MFU calculation uses active params (3.3B) for MoE models
- Update all local train files for new model and optimizer

Made-with: Cursor
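To illustrate why the MFU calculation uses active rather than total parameters, here is a minimal sketch. It assumes the common approximation of 6 FLOPs per parameter per token for a combined forward and backward pass; the PR does not show its actual formula. The function name `estimate_mfu`, the throughput, the GPU count, and the peak FLOP rate are all hypothetical values chosen for illustration.

```python
# Hypothetical sketch: MFU for an MoE model, counting only active parameters.
# Assumption: training cost is roughly 6 FLOPs per parameter per token
# (forward + backward), ignoring attention FLOPs.

TOTAL_PARAMS = 30.5e9   # Qwen3-30B-A3B total parameters (from the PR description)
ACTIVE_PARAMS = 3.3e9   # parameters actually used per token (from the PR description)


def estimate_mfu(params, tokens_per_second, num_gpus, peak_flops_per_gpu):
    """Return achieved FLOP/s divided by the aggregate hardware peak."""
    flops_per_token = 6.0 * params
    achieved_flops = flops_per_token * tokens_per_second
    return achieved_flops / (num_gpus * peak_flops_per_gpu)


# Illustrative numbers only; none of these come from the PR.
tokens_per_second = 20_000
num_gpus = 8
peak_flops_per_gpu = 1e15

mfu_active = estimate_mfu(ACTIVE_PARAMS, tokens_per_second, num_gpus, peak_flops_per_gpu)
mfu_total = estimate_mfu(TOTAL_PARAMS, tokens_per_second, num_gpus, peak_flops_per_gpu)
```

Under these assumptions, using total parameters would overstate utilization by the ratio of total to active parameters (about 9x here), because most expert weights do no work for any given token.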
Walkthrough
The PR introduces expert-parallel (EP) training support for Mixture-of-Experts (MoE) models by extending parallelism configuration with
Changes
Sequence Diagram
sequenceDiagram
participant Trainer
participant Model as Model (Sharded<br/>Experts)
participant RouterLogits as Router & Top-K
participant LocalExperts as Local Expert<br/>Processors
participant EPAllReduce as EP AllReduce<br/>Group
participant Optimizer
Trainer->>Model: Forward Pass
Model->>RouterLogits: Compute router logits
RouterLogits->>RouterLogits: Compute routing weights<br/>(softmax + top-k)
RouterLogits->>LocalExperts: Select top-k experts<br/>& token indices
LocalExperts->>LocalExperts: Execute only local<br/>expert shards
LocalExperts->>EPAllReduce: Sum expert outputs<br/>across EP ranks
EPAllReduce->>Model: Aggregated expert output
Model->>Trainer: Logits + Router Logits
Trainer->>Trainer: Compute next-token<br/>cross-entropy loss
Trainer->>Trainer: Backward propagation
Trainer->>EPAllReduce: All-reduce replicated<br/>gradients (excl. experts)
EPAllReduce->>Optimizer: Synchronized gradients
Optimizer->>Model: SGD update with<br/>lr=0.1, momentum=0.0
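The forward path in the diagram (route, run only local experts, sum across EP ranks) can be sketched in a single process with plain Python. This is a toy simulation, not the PR's `EPMoeBlock`: the contiguous expert partition, the renormalized top-k weights, the `toy_expert` function, and all helper names are assumptions made for illustration, and `all_reduce_sum` stands in for a real collective.

```python
import math


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_top_k(router_logits, k):
    """Map each selected expert id to its routing weight (renormalized, an assumption)."""
    probs = softmax(router_logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in chosen)
    return {i: probs[i] / norm for i in chosen}


def local_expert_ids(rank, ep_size, num_experts):
    """Contiguous partition of experts over EP ranks (hypothetical layout)."""
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size")
    per_rank = num_experts // ep_size
    return range(rank * per_rank, (rank + 1) * per_rank)


def toy_expert(expert_id, token):
    # Stand-in for an expert MLP: scale the token by (expert_id + 1).
    return [(expert_id + 1) * v for v in token]


def partial_output(rank, ep_size, num_experts, token, routing):
    """Weighted output of the selected experts that live on this rank only."""
    out = [0.0] * len(token)
    for expert_id in local_expert_ids(rank, ep_size, num_experts):
        weight = routing.get(expert_id)
        if weight is None:
            continue  # expert not selected for this token
        for j, v in enumerate(toy_expert(expert_id, token)):
            out[j] += weight * v
    return out


def all_reduce_sum(partials):
    # Simulates a sum all-reduce: element-wise sum of every rank's partial output.
    return [sum(vals) for vals in zip(*partials)]


def moe_forward(token, router_logits, k, ep_size):
    num_experts = len(router_logits)
    routing = route_top_k(router_logits, k)  # every rank computes the same routing
    partials = [
        partial_output(rank, ep_size, num_experts, token, routing)
        for rank in range(ep_size)
    ]
    return all_reduce_sum(partials)


token = [1.0, 2.0]
router_logits = [0.1, 2.0, 0.3, 1.5]
dense = moe_forward(token, router_logits, k=2, ep_size=1)
sharded = moe_forward(token, router_logits, k=2, ep_size=2)
```

In this sketch, `sharded` matches `dense` up to floating-point rounding: each rank contributes zeros for experts it does not own, so summing the partial outputs reproduces the unsharded result.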
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
🧹 Nitpick comments (3)
local_test/train_ep.py (2)
171-171: Optional: Rename unused loop variable `step` to `_step`.
Per Python conventions, prefix unused variables with an underscore.
Suggested fix
- for step in range(num_steps):
+ for _step in range(num_steps):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_ep.py` at line 171, the loop variable `step` in the training loop "for step in range(num_steps):" is unused; rename it to `_step` to follow Python conventions for unused variables and silence linters. Update the loop header in train_ep.py accordingly (e.g., change `for step in range(num_steps):` to use `_step`) and ensure no other references to `step` exist in the surrounding function.
123-123: Add `strict=True` to `zip()` for safety.
The `expert_items` and `shapes` lists should always have the same length since they're derived from the same source, but adding `strict=True` would catch any unexpected mismatch.
Suggested fix
- for (name, _), shape in zip(expert_items, shapes):
+ for (name, _), shape in zip(expert_items, shapes, strict=True):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_ep.py` at line 123, the loop using zip(expert_items, shapes) should enforce length equality: update the zip call in the for loop that iterates "for (name, _), shape in zip(expert_items, shapes):" to use zip(..., strict=True) so any mismatch between expert_items and shapes raises an immediate error; modify that specific zip invocation accordingly.
local_test/train.py (1)
99-99: Optional: Rename unused loop variable `step` to `_step`.
The loop variable is unused within the loop body. Renaming to `_step` follows Python conventions for intentionally unused variables and silences the linter warning.
Suggested fix
- for step in range(num_steps):
+ for _step in range(num_steps):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train.py` at line 99, the loop variable "step" in the for loop "for step in range(num_steps):" is unused and should be renamed to "_step" to follow Python conventions and silence linters; update the loop header in train.py to use "for _step in range(num_steps):" and ensure no other references to "step" exist elsewhere in that scope (adjust if necessary).
ℹ️ Review info
📒 Files selected for processing (9)
- environments/templar/Dockerfile
- environments/templar/env.py
- local_test/simulate_validator.py
- local_test/train.py
- local_test/train_ddp.py
- local_test/train_ep.py
- local_test/train_fsdp.py
- local_test/train_mixed.py
- local_test/train_tp.py