V1 muon optimizer#10618
Conversation
- Register a `muon` OptimizerPlugin in v1, vendored into v1/plugins/trainer_plugins/muon_optimizer.py (no dependency on v0 `llamafactory.third_party.muon`). - Muon for 2D weight matrices (excluding embed/lm_head); built-in AdamW for the rest. Warns when run under FSDP2 (DTensor shards) that NS is approximate until a DTensor-aware variant is added. - Add a one-time, rank0, env-gated (LLAMAFACTORY_MUON_DIAG=1) diagnostic in Muon.step to gather param/grad/data types needed for the DTensor-aware v2. - Add examples/v1/train_full/train_full_muon.yaml. - Add run_ulysses.sh and update train_full_ulysses_cp.yaml for v1 Ulysses sequence-parallel SFT launch.
Muon v1 ran Newton-Schulz on FSDP2's DTensor shards, which computes a partial Gram matrix and makes the NS iteration diverge -> NaN at step 2. v2 changes only the Muon branch of the step: - all-gather the full gradient via g.full_tensor() - run Newton-Schulz on the full 2D matrix - scatter the update back to the local shard via distribute_tensor, then add The AdamW branch is unchanged (elementwise on shards is already correct). Trade-off: +1 all-gather +1 scatter per Muon param per step, and the momentum buffer is stored full (doubled optimizer-state memory for Muon params). Also: - optimizer.py: drop the now-obsolete FSDP2 "approximate" warning and the unused _is_dtensor helper. - Add tests_v1/.../test_ulysses_cp_precision.py: CP-on vs CP-off loss/grad agreement test. - train_full_muon.yaml: re-enable fsdp2 (v2 supports it).
There was a problem hiding this comment.
Code Review
This pull request introduces the Muon optimizer plugin, which is designed to be DTensor-aware and compatible with FSDP2. It includes the implementation of the Muon optimizer with Newton-Schulz iteration, its registration as a trainer plugin, and an example configuration file. The review feedback highlights two important improvements: first, the zeropower_via_newtonschulz5 function should cast its output back to the original gradient dtype to prevent runtime errors during full-precision training; second, the parameter filtering logic should be expanded to exclude alternative embedding names (like 'wte' and 'wpe') and LoRA adapter weights from Muon optimization.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if G.size(0) > G.size(1): | ||
| X = X.T | ||
| return X |
There was a problem hiding this comment.
The zeropower_via_newtonschulz5 function casts the input tensor to bfloat16 for computation but returns the result without casting it back to the original dtype (G.dtype). If the model parameters (p.data) are in float32 (e.g., during full-precision training or when using master weights), performing the in-place addition p.data.add_(...) with a bfloat16 update tensor will raise a RuntimeError due to dtype mismatch.
Casting the returned tensor back to G.dtype ensures compatibility with the original parameter precision.
| if G.size(0) > G.size(1): | |
| X = X.T | |
| return X | |
| if G.size(0) > G.size(1): | |
| X = X.T | |
| return X.to(G.dtype) |
| if param.ndim == 2 and "embed" not in name and "lm_head" not in name: | ||
| muon_params.append(param) | ||
| else: | ||
| adamw_params.append(param) |
There was a problem hiding this comment.
The current parameter filtering logic only excludes parameters containing "embed" or "lm_head" from Muon optimization. However, some models use alternative names for embeddings (such as "wte" and "wpe" in GPT-2). Additionally, if LoRA is used, the 2D adapter weights (containing "lora") will be incorrectly optimized by Muon instead of AdamW.
We should explicitly exclude "wte", "wpe", and "lora" from Muon optimization to ensure they are correctly optimized by AdamW.
if (
param.ndim == 2
and "embed" not in name
and "lm_head" not in name
and "wte" not in name
and "wpe" not in name
and "lora" not in name
):
muon_params.append(param)
else:
adamw_params.append(param)
What does this PR do?
support for muon
Before submitting