Skip to content

V1 muon optimizer#10618

Open
HelloWorldBeginner wants to merge 2 commits into
hiyouga:mainfrom
HelloWorldBeginner:v1-muon-optimizer
Open

V1 muon optimizer#10618
HelloWorldBeginner wants to merge 2 commits into
hiyouga:mainfrom
HelloWorldBeginner:v1-muon-optimizer

Conversation

@HelloWorldBeginner

Copy link
Copy Markdown
Contributor

What does this PR do?

support for muon

Before submitting

mhh111 added 2 commits June 30, 2026 18:30
- Register a `muon` OptimizerPlugin in v1, vendored into
  v1/plugins/trainer_plugins/muon_optimizer.py (no dependency on v0
  `llamafactory.third_party.muon`).
- Muon for 2D weight matrices (excluding embed/lm_head); built-in AdamW
  for the rest. Warns when run under FSDP2 (DTensor shards) that NS is
  approximate until a DTensor-aware variant is added.
- Add a one-time, rank0, env-gated (LLAMAFACTORY_MUON_DIAG=1) diagnostic
  in Muon.step to gather param/grad/data types needed for the DTensor-aware v2.
- Add examples/v1/train_full/train_full_muon.yaml.
- Add run_ulysses.sh and update train_full_ulysses_cp.yaml for v1 Ulysses
  sequence-parallel SFT launch.
Muon v1 ran Newton-Schulz on FSDP2's DTensor shards, which computes a partial
Gram matrix and makes the NS iteration diverge -> NaN at step 2.

v2 changes only the Muon branch of the step:
- all-gather the full gradient via g.full_tensor()
- run Newton-Schulz on the full 2D matrix
- scatter the update back to the local shard via distribute_tensor, then add
The AdamW branch is unchanged (elementwise on shards is already correct).

Trade-off: +1 all-gather +1 scatter per Muon param per step, and the momentum
buffer is stored full (doubled optimizer-state memory for Muon params).

Also:
- optimizer.py: drop the now-obsolete FSDP2 "approximate" warning and the
  unused _is_dtensor helper.
- Add tests_v1/.../test_ulysses_cp_precision.py: CP-on vs CP-off loss/grad
  agreement test.
- train_full_muon.yaml: re-enable fsdp2 (v2 supports it).

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the Muon optimizer plugin, which is designed to be DTensor-aware and compatible with FSDP2. It includes the implementation of the Muon optimizer with Newton-Schulz iteration, its registration as a trainer plugin, and an example configuration file. The review feedback highlights two important improvements: first, the zeropower_via_newtonschulz5 function should cast its output back to the original gradient dtype to prevent runtime errors during full-precision training; second, the parameter filtering logic should be expanded to exclude alternative embedding names (like 'wte' and 'wpe') and LoRA adapter weights from Muon optimization.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +85 to +87
if G.size(0) > G.size(1):
X = X.T
return X

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The zeropower_via_newtonschulz5 function casts the input tensor to bfloat16 for computation but returns the result without casting it back to the original dtype (G.dtype). If the model parameters (p.data) are in float32 (e.g., during full-precision training or when using master weights), performing the in-place addition p.data.add_(...) with a bfloat16 update tensor will raise a RuntimeError due to dtype mismatch.

Casting the returned tensor back to G.dtype ensures compatibility with the original parameter precision.

Suggested change
if G.size(0) > G.size(1):
X = X.T
return X
if G.size(0) > G.size(1):
X = X.T
return X.to(G.dtype)

Comment on lines +50 to +53
if param.ndim == 2 and "embed" not in name and "lm_head" not in name:
muon_params.append(param)
else:
adamw_params.append(param)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current parameter filtering logic only excludes parameters containing "embed" or "lm_head" from Muon optimization. However, some models use alternative names for embeddings (such as "wte" and "wpe" in GPT-2). Additionally, if LoRA is used, the 2D adapter weights (containing "lora") will be incorrectly optimized by Muon instead of AdamW.

We should explicitly exclude "wte", "wpe", and "lora" from Muon optimization to ensure they are correctly optimized by AdamW.

            if (
                param.ndim == 2
                and "embed" not in name
                and "lm_head" not in name
                and "wte" not in name
                and "wpe" not in name
                and "lora" not in name
            ):
                muon_params.append(param)
            else:
                adamw_params.append(param)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant