Conversation


@shuningjin commented Oct 25, 2025

Description

Prerequisite: this PR requires the latest changes in optax.contrib.muon.

# Install the specific commit
pip install git+https://github.com/google-deepmind/optax@9858013795e22958fc2b318fb59f254bf700b10e

or

# uninstall the old version first
pip uninstall optax
# Install the latest 'main' branch 
pip install git+https://github.com/google-deepmind/optax

What this PR does

3D and 4D parameters are logically 2D. We use a MuonDimensionNumber (mdn) to specify how each parameter is reshaped to 2D.

  • Note: reduction_dim is the in-feature dim(s) and output_dim the out-feature dim(s); the remaining dims are batched over. Dims can be negative, e.g., 0 is the 0th dim, -1 the last, -2 the second to last. Dims grouped together are flattened into one.
  • e.g., decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wo with shape (num_experts, num_layer, base_moe_mlp_dim, base_emb_dim): reduction_dim = (-2,), output_dim = (-1,)
  • e.g., decoder.moe_layers.self_attention.out.kernel with shape (base_num_query_heads, num_layer, v_head_dim, base_emb_dim): reduction_dim = (0, -2), output_dim = (-1,)
  • As Muon is designed for 2D parameters, we do not apply it to scalars and vectors (norms, biases). Additionally, we do not apply it to the embedding and unembedding, as previous work suggests this is empirically better. A reshape sketch follows this list.
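
A minimal reshape sketch, assuming only the reduction_dim/output_dim semantics above; the Mdn class and to_logical_2d helper are illustrative stand-ins, not the PR's actual API:

import dataclasses
import numpy as np

@dataclasses.dataclass(frozen=True)
class Mdn:  # illustrative stand-in for MuonDimensionNumber
  reduction_dim: tuple  # in-feature dims
  output_dim: tuple     # out-feature dims

def to_logical_2d(x, mdn):
  """Reshape to (batch, in, out); Muon then acts on each (in, out) slice."""
  red = tuple(d % x.ndim for d in mdn.reduction_dim)
  out = tuple(d % x.ndim for d in mdn.output_dim)
  batch = tuple(d for d in range(x.ndim) if d not in red + out)
  b = int(np.prod([x.shape[d] for d in batch]))  # np.prod([]) == 1
  r = int(np.prod([x.shape[d] for d in red]))    # grouped dims are flattened
  o = int(np.prod([x.shape[d] for d in out]))
  return np.transpose(x, batch + red + out).reshape(b, r, o)

# self_attention.out.kernel: (heads, layers, v_head_dim, emb_dim)
w = np.zeros((8, 4, 128, 512))
print(to_logical_2d(w, Mdn(reduction_dim=(0, -2), output_dim=(-1,))).shape)
# -> (4, 1024, 512): layers are batched over; heads fold into v_head_dim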

To generalize the reshaping

  • Given the abstract model params, we extract a static tree of mdn using name-based rules (see the sketch after this list).
  • (Alternatively, we could pass a callable mdn to Muon; however, this would add overhead to every update.)
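
A hedged sketch of the name-based rules; the patterns, the (reduction_dim, output_dim) tuples, and the mdn_for helper are illustrative, not the PR's code:

import re

# regex over the parameter path -> (reduction_dim, output_dim)
_MDN_RULES = (
    (re.compile(r"MoeBlock_0\.wo$"), ((-2,), (-1,))),
    (re.compile(r"self_attention\.out\.kernel$"), ((0, -2), (-1,))),
)

def mdn_for(param_path):
  """Return the mdn for a parameter path, or None to fall back to AdamW."""
  for pattern, mdn in _MDN_RULES:
    if pattern.search(param_path):
      return mdn
  return None

Because the rules run once over the abstract parameter tree, the result is a static pytree of mdn (or None) rather than a callable evaluated on every step.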

Example use

  • Should work for deepseek2, deepseek3, llama2, and gemma3 (e.g., deepseek2-16b, gemma3-4b, llama2-7b) with scanned layers
  • Works with different shardings (e.g., tested fsdp, dp, tp)
  • opt_type=muon, optionally muon_beta=0.95, muon_weight_decay=0.1, muon_consistent_rms=0.2
  • Muon is used together with AdamW: Muon updates the logically 2D parameters and AdamW the rest (see the sketch after this list).
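
A hedged sketch of the Muon/AdamW split via optax.multi_transform, assuming nested-dict params and the mdn_for helper sketched above; the hyperparameter names mirror the PR's flags but are assumptions about optax.contrib.muon's kwargs, and make_optimizer is illustrative, not MaxText's factory:

import jax
import optax

def make_optimizer(abstract_params, learning_rate=5e-4):
  # Label each leaf once from the static rules: Muon where an mdn exists,
  # AdamW elsewhere (scalars, biases, embedding, unembedding).
  def label(path, _):
    name = ".".join(str(getattr(k, "key", k)) for k in path)
    return "muon" if mdn_for(name) else "adamw"

  labels = jax.tree_util.tree_map_with_path(label, abstract_params)
  return optax.multi_transform(
      {
          "muon": optax.contrib.muon(learning_rate, beta=0.95, weight_decay=0.1),
          "adamw": optax.adamw(learning_rate, weight_decay=0.1),
      },
      labels,
  )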

Pretrain

BASE_OUTPUT_PATH=gs://runner-maxtext-logs
RUN_NAME=muon-$(date +%Y-%m-%d-%H-%M-%S)

python3 -m MaxText.train MaxText/configs/base.yml \
base_output_directory=$BASE_OUTPUT_PATH run_name=$RUN_NAME \
model_name=gemma3-4b \
tokenizer_type=sentencepiece tokenizer_path=src/MaxText/assets/tokenizer.gemma3 \
dataset_type=tfds dataset_path='gs://mlperf-llm-public2' dataset_name='c4/en:3.0.4' train_split='train2' \
enable_checkpointing=false dtype=bfloat16 weight_dtype=bfloat16 \
opt_type=muon learning_rate=5e-4 adam_weight_decay=0.1 muon_weight_decay=0.1 muon_consistent_rms=0.2 \
per_device_batch_size=16 max_target_length=2048 steps=20 \
ici_fsdp_parallelism=4 ici_data_parallelism=1 ici_tensor_parallelism=1 \
cosine_learning_rate_final_fraction=0.1 warmup_steps_fraction=0.1 learning_rate_schedule_steps=-1 \
override_model_config=true enable_dropout=false \
profiler=xplane skip_first_n_steps_for_profiler=5 profiler_steps=3

Train compile

BASE_OUTPUT_PATH=gs://runner-maxtext-logs
RUN_NAME=muon-$(date +%Y-%m-%d-%H-%M-%S)

python3 -m MaxText.train_compile MaxText/configs/base.yml \
base_output_directory=$BASE_OUTPUT_PATH run_name=$RUN_NAME \
model_name=gemma3-4b \
tokenizer_type=sentencepiece tokenizer_path=src/MaxText/assets/tokenizer.gemma3 \
dataset_type=tfds dataset_path='gs://mlperf-llm-public2' dataset_name='c4/en:3.0.4' train_split='train2' \
enable_checkpointing=false dtype=bfloat16 weight_dtype=bfloat16 \
opt_type=muon learning_rate=5e-4 adam_weight_decay=0.1 muon_weight_decay=0.1 muon_consistent_rms=0.2 \
per_device_batch_size=16 max_target_length=2048 steps=20 \
ici_fsdp_parallelism=2 ici_data_parallelism=1 ici_tensor_parallelism=2 \
cosine_learning_rate_final_fraction=0.1 warmup_steps_fraction=0.1 learning_rate_schedule_steps=-1 \
override_model_config=true enable_dropout=false \
compile_topology=v5p-8 compile_topology_num_slices=1

Tests

Unit test for the reshape: python -m MaxText.muon_dimension_number

End-to-end test: b/437908829

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.
