Commit 245759b ("Add muon optimizer")
Parent: 72979a3

7 files changed: +249 −5 lines

src/MaxText/configs/base.yml
Lines changed: 9 additions & 1 deletion

@@ -622,7 +622,7 @@ gradient_clipping_threshold: 1.0
 # batch by accumulating the gradient over a set of steps.
 gradient_accumulation_steps: 1

-opt_type: "adamw" # one of "adamw", "adam_pax" or "sgd"
+opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon"

 # AdamW optimizer parameters
 # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
@@ -635,6 +635,14 @@ mu_dtype: "" # data type to store "mu" of AdamW tracking the first moment. Inher
 # Setting nu_dtype is not yet supported by optax, instead nu_dtype is always inherited from weights.
 # See b/399961932 for more.

+# Muon optimizer parameters
+# https://github.com/google-deepmind/optax/blob/main/optax/contrib/_muon.py
+# "mu_dtype" and "adam_eps" are shared with AdamW.
+# "nesterov", "ns_coeffs", "ns_steps", "weight_decay_mask", "adaptive" use their defaults.
+muon_beta: 0.95 # Decay rate for the exponentially weighted average of grads.
+muon_weight_decay: 0 # Strength of the weight decay regularization, multiplied with the learning rate.
+muon_consistent_rms: None # If None, apply width scaling to updates. If a float, apply consistent RMS scaling (0.2 recommended).
+
 # Stack trace parameters
 collect_stack_trace: False
 stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False.
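For orientation, a minimal sketch (not part of this commit) of exercising the new keys through config overrides; the override style mirrors the test in muon_dimension_number.py below, the key names come from this file, and the specific values are illustrative only:

from MaxText import pyconfig

# Illustrative overrides in the argv style pyconfig.initialize expects.
argv = [
    None,
    "src/MaxText/configs/base.yml",
    "opt_type=muon",
    "muon_beta=0.95",
    "muon_weight_decay=0",
    "muon_consistent_rms=0.2",  # or "None" to keep width scaling of updates
]
config = pyconfig.initialize(argv)
assert config.opt_type == "muon" and config.muon_consistent_rms == 0.2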

src/MaxText/muon_dimension_number.py (new file)
Lines changed: 195 additions & 0 deletions

@@ -0,0 +1,195 @@
"""Muon weight dimension numbers for MaxText models, consumed by optimizers.get_optimizer."""

import jax
import jax.numpy as jnp
from optax.contrib import MuonDimensionNumbers as mdn

# deepseek2-16b, scanned, q_lora_rank=0
# NOTE: not compatible with deepseek2-236b (q_lora_rank: 1536)
DEEPSEEK2_DIMENSION_NUMBER = {
  "params": {
    "decoder": {
      "dense_layers": {
        "mlp": {
          "wi_0": {"kernel": mdn((0,), (-1,))},
          "wi_1": {"kernel": mdn((0,), (-1,))},
          "wo": {"kernel": mdn((0,), (-1,))},
        },
        "self_attention": {
          "kv_norm": {"scale": None},
          "wkv_a": {"kernel": mdn((0,), (-1,))},
          "wkv_b": {"kernel": mdn((0,), (-2, -1))},
          "out": {"kernel": mdn((0, -2), (-1,))},
          "query": {"kernel": mdn((0,), (-2, -1))},  # ds2
        },
        "pre_self_attention_layer_norm": {"scale": None},
        "post_self_attention_layer_norm": {"scale": None},
      },
      "moe_layers": {
        "DeepSeekMoeBlock_0": {
          "MoeBlock_0": {
            "wi_0": mdn((-2,), (-1,)),
            "wi_1": mdn((-2,), (-1,)),
            "wo": mdn((-2,), (-1,)),
            "gate": {"kernel": mdn((0,), (-1,))},  # ds2
          },
          "shared_experts": {
            "wi_0": {"kernel": mdn((0,), (-1,))},
            "wi_1": {"kernel": mdn((0,), (-1,))},
            "wo": {"kernel": mdn((0,), (-1,))},
          },
        },
        "self_attention": {
          "kv_norm": {"scale": None},
          "wkv_a": {"kernel": mdn((0,), (-1,))},
          "wkv_b": {"kernel": mdn((0,), (-2, -1))},
          "out": {"kernel": mdn((0, -2), (-1,))},
          "query": {"kernel": mdn((0,), (-2, -1))},  # ds2
        },
        "pre_self_attention_layer_norm": {"scale": None},
        "post_self_attention_layer_norm": {"scale": None},
      },
      "decoder_norm": {"scale": None},
      "logits_dense": {"kernel": None},
    },
    "token_embedder": {"embedding": None},
  }
}


# deepseek3, scanned
DEEPSEEK3_DIMENSION_NUMBER = {
  "params": {
    "decoder": {
      "dense_layers": {
        "mlp": {
          "wi_0": {"kernel": mdn((0,), (-1,))},
          "wi_1": {"kernel": mdn((0,), (-1,))},
          "wo": {"kernel": mdn((0,), (-1,))},
        },
        "self_attention": {
          "kv_norm": {"scale": None},
          "wkv_a": {"kernel": mdn((0,), (-1,))},
          "wkv_b": {"kernel": mdn((0,), (-2, -1))},
          "out": {"kernel": mdn((0, -2), (-1,))},
          "q_norm": {"scale": None},  # ds3
          "wq_a": {"kernel": mdn((0,), (-1,))},  # ds3
          "wq_b": {"kernel": mdn((0,), (-2, -1))},  # ds3
        },
        "pre_self_attention_layer_norm": {"scale": None},
        "post_self_attention_layer_norm": {"scale": None},
      },
      "moe_layers": {
        "DeepSeekMoeBlock_0": {
          "MoeBlock_0": {
            "wi_0": mdn((-2,), (-1,)),
            "wi_1": mdn((-2,), (-1,)),
            "wo": mdn((-2,), (-1,)),
            "gate": {"kernel": mdn((0,), (-1,)), "bias": None},  # ds3
          },
          "shared_experts": {
            "wi_0": {"kernel": mdn((0,), (-1,))},
            "wi_1": {"kernel": mdn((0,), (-1,))},
            "wo": {"kernel": mdn((0,), (-1,))},
          },
        },
        "self_attention": {
          "kv_norm": {"scale": None},
          "wkv_a": {"kernel": mdn((0,), (-1,))},
          "wkv_b": {"kernel": mdn((0,), (-2, -1))},
          "out": {"kernel": mdn((0, -2), (-1,))},
          "q_norm": {"scale": None},  # ds3
          "wq_a": {"kernel": mdn((0,), (-1,))},  # ds3
          "wq_b": {"kernel": mdn((0,), (-2, -1))},  # ds3
        },
        "pre_self_attention_layer_norm": {"scale": None},
        "post_self_attention_layer_norm": {"scale": None},
      },
      "decoder_norm": {"scale": None},
      "logits_dense": {"kernel": None},
    },
    "token_embedder": {"embedding": None},
  }
}


def transform_logic(path):
  """Map a parameter path (tuple of pytree keys) to its MuonDimensionNumbers.

  Assumes scanned layers (i.e., dim 1 is the layer index L); should also work
  unscanned (without L). Works for deepseek, llama2, gemma3.
  """
  # moe expert weights: [0, L, -2, -1]
  if "MoeBlock_0" in path and ("wo" in path or "wi_0" in path or "wi_1" in path):
    return mdn((-2,), (-1,))
  # attention out proj: [0, L, -2, -1]
  elif "self_attention" in path and "out" in path:
    return mdn((0, -2), (-1,))
  # attention qkv proj: [0, L, -2, -1]
  elif "self_attention" in path and (
      "query" in path or "key" in path or "value" in path or "wq_b" in path or "wkv_b" in path
  ):
    return mdn((0,), (-2, -1))
  # do not apply muon: scalar, embedding, unembedding
  elif "scale" in path or "bias" in path or "embedding" in path or "logits_dense" in path:
    return None
  else:
    # all other: [0, L, -1]
    return mdn((0,), (-1,))


def get_transform_tree(tree, path=()):
  """Replace every leaf of a nested param dict with its MuonDimensionNumbers."""
  if isinstance(tree, dict):
    return {k: get_transform_tree(v, path + (k,)) for k, v in tree.items()}
  else:
    return transform_logic(path)


def get_abstract_param(model, config):
  """Get the model's parameter structure via jax.eval_shape, without materializing weights."""
  key = jax.random.PRNGKey(0)
  input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
  abstract_vars = jax.eval_shape(
      model.init,
      {"params": key, "dropout": key, "aqt": key},
      jnp.ones(input_shape, dtype=jnp.int32),
      jnp.ones(input_shape, dtype=jnp.int32),
      encoder_images=None,
  )
  return abstract_vars


def test1():
  assert get_transform_tree(DEEPSEEK2_DIMENSION_NUMBER) == DEEPSEEK2_DIMENSION_NUMBER
  assert get_transform_tree(DEEPSEEK3_DIMENSION_NUMBER) == DEEPSEEK3_DIMENSION_NUMBER


def test2():
  from MaxText import pyconfig, maxtext_utils
  from MaxText.globals import MAXTEXT_PKG_DIR
  from MaxText.layers import models, quantizations
  import os

  Transformer = models.transformer_as_linen

  def _test2(model_name):
    # init model
    argv = [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), f"model_name={model_name}"]
    config = pyconfig.initialize(argv)
    rng = jax.random.PRNGKey(0)
    devices_array = maxtext_utils.create_device_mesh(config)
    mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
    quant = quantizations.configure_quantization(config)
    model = Transformer(config, mesh=mesh, quant=quant)
    # quickly get the param structure without materialization
    abstract_param = get_abstract_param(model, config)
    print(abstract_param)
    # get muon dimension numbers
    transform_tree = get_transform_tree(abstract_param)
    return transform_tree

  assert _test2("deepseek2-16b") == DEEPSEEK2_DIMENSION_NUMBER
  assert _test2("deepseek3-test") == DEEPSEEK3_DIMENSION_NUMBER
  assert _test2("deepseek3-671b") == DEEPSEEK3_DIMENSION_NUMBER


if __name__ == "__main__":
  # python -m MaxText.muon_dimension_number
  test1()
  test2()
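As a quick illustration (not part of the commit) of how transform_logic resolves a parameter path: paths are tuples of pytree keys and the membership tests match whole components, so representative lookups resolve as follows.

from optax.contrib import MuonDimensionNumbers as mdn
from MaxText.muon_dimension_number import transform_logic

# MoE expert weights hit the first branch.
moe_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "wi_0")
assert transform_logic(moe_path) == mdn((-2,), (-1,))

# Query/key/value (and wq_b/wkv_b) projections hit the qkv branch.
qkv_path = ("params", "decoder", "dense_layers", "self_attention", "query", "kernel")
assert transform_logic(qkv_path) == mdn((0,), (-2, -1))

# Norm scales, biases, embeddings, and the unembedding are excluded from Muon.
assert transform_logic(("params", "decoder", "decoder_norm", "scale")) is None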

src/MaxText/optimizers.py
Lines changed: 29 additions & 1 deletion

@@ -17,11 +17,14 @@

 import jax
 import jax.numpy as jnp
+from flax.linen import partitioning as nn_partitioning

 import optax
+from optax.contrib import muon
+from MaxText.muon_dimension_number import get_abstract_param, get_transform_tree


-def get_optimizer(config, learning_rate_schedule):
+def get_optimizer(config, learning_rate_schedule, model=None):
   """Create optimizer."""
   if config.opt_type == "adamw":
     # Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
@@ -45,6 +48,31 @@ def get_optimizer(config, learning_rate_schedule):
     )
   elif config.opt_type == "sgd":
     return optax.sgd(learning_rate_schedule)
+  elif config.opt_type == "muon":
+    # Extract muon dimension numbers from the model structure.
+    assert model is not None
+    with model.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      abstract_param = get_abstract_param(model, config)
+    print(abstract_param)
+    muon_weight_dimension_numbers = get_transform_tree(abstract_param)
+    print("dimension number:", muon_weight_dimension_numbers)
+    muon_kwargs = {
+        # Shared parameters; "nesterov" uses its default.
+        "learning_rate": learning_rate_schedule,
+        "eps": config.adam_eps,
+        "mu_dtype": config.mu_dtype,
+        # Muon-specific parameters; "ns_coeffs", "ns_steps", "weight_decay_mask", "adaptive" use their defaults.
+        "beta": config.muon_beta,
+        "weight_decay": config.muon_weight_decay,
+        "muon_weight_dimension_numbers": muon_weight_dimension_numbers,
+        "consistent_rms": config.muon_consistent_rms,
+        # AdamW-specific parameters.
+        "adam_b1": config.adam_b1,
+        "adam_b2": config.adam_b2,
+        "adam_eps_root": config.adam_eps_root,
+        "adam_weight_decay": config.adam_weight_decay,
+    }
+    return muon(**muon_kwargs)
   else:
     raise ValueError(f"{config.opt_type=} is not supported.")
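For context, a toy sketch (not part of this commit) of the optax call this branch ultimately makes; the kwarg names are taken from muon_kwargs above, every other optax.contrib.muon argument is assumed to stay at its default, and the shapes are arbitrary.

import jax
import jax.numpy as jnp
import optax
from optax.contrib import muon, MuonDimensionNumbers as mdn

# One matrix routed through Muon, one scalar-like parameter opted out via None.
params = {"w": jnp.ones((4, 8)), "scale": jnp.ones((8,))}
dims = {"w": mdn((0,), (-1,)), "scale": None}

tx = muon(learning_rate=1e-3, muon_weight_dimension_numbers=dims)
opt_state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)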

src/MaxText/pyconfig.py
Lines changed: 11 additions & 0 deletions

@@ -483,6 +483,16 @@ def resolve_config_path(param: str) -> str:
   return param if os.path.isfile(param) else os.path.join("src", param)


+def set_muon_config(raw_keys):
+  if raw_keys["muon_consistent_rms"] in ["None", "none"]:
+    raw_keys["muon_consistent_rms"] = None
+  else:
+    try:
+      raw_keys["muon_consistent_rms"] = float(raw_keys["muon_consistent_rms"])
+    except ValueError as e:
+      raise ValueError("muon_consistent_rms should be None or a float") from e
+
+
 class _HyperParameters:
   # pylint: disable=missing-class-docstring
   # This class is responsible for loading, merging, and overriding the configuration.
@@ -735,6 +745,7 @@ def user_init(raw_keys):
   raw_keys["mu_dtype"] = set_mu_dtype(raw_keys)
   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
   raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])
+  set_muon_config(raw_keys)

   if raw_keys["remat_policy"] == "custom":
     raw_keys = validate_and_assign_remat_tensors(raw_keys)
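A small illustration (not part of the commit) of what set_muon_config does to the raw key; the default YAML value "None" and command-line overrides both arrive as strings:

raw = {"muon_consistent_rms": "None"}
set_muon_config(raw)
assert raw["muon_consistent_rms"] is None  # "None"/"none" -> None (width scaling)

raw = {"muon_consistent_rms": "0.2"}
set_muon_config(raw)
assert raw["muon_consistent_rms"] == 0.2  # numeric strings -> float (consistent RMS scaling)

# Anything else raises ValueError, chained from the failed float() conversion.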

src/MaxText/sft/sft_trainer.py
Lines changed: 1 addition & 1 deletion

@@ -148,7 +148,7 @@ def train(mt_config, goodput_recorder=None):
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
     model, mesh = model_creation_utils.create_nnx_model(mt_config)
     learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
-    optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule)
+    optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)

   with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
     training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)

src/MaxText/train_compile.py
Lines changed: 2 additions & 1 deletion

@@ -88,7 +88,8 @@ def get_shaped_inputs(topology_mesh, config):
   model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
   # The learning_rate_schedule is baked into the compiled object.
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
-  tx = optimizers.get_optimizer(config, learning_rate_schedule)
+  # pass in model for muon
+  tx = optimizers.get_optimizer(config, learning_rate_schedule, model)

   # Shaped RNG keys
   _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2)

src/MaxText/train_utils.py
Lines changed: 2 additions & 1 deletion

@@ -33,7 +33,8 @@ def create_training_tools(config, model, mesh):
   """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager."""
   init_rng = jax.random.PRNGKey(config.init_weights_seed)
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
-  tx = optimizers.get_optimizer(config, learning_rate_schedule)
+  # pass in model for muon
+  tx = optimizers.get_optimizer(config, learning_rate_schedule, model)
   logger = checkpointing.setup_checkpoint_logger(config)
   if config.enable_multi_tier_checkpointing:
     checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager(
