Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
Prerequisite: Need latest changes in optax.contrib.muon
or
What this PR does
3D and 4D parameters are logically 2D. Use the MuonDimensionNumber (mdn) for reshaping specification.
reduction_dim- in feature,output_dim- out feature, the rest dims are batch over. dims can be negative number, e.g., 0 is 0th dim, -1 is the last dim, -2 is second to last. dims grouped together are flatten.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wo, (num_experts, num_layer, base_moe_mlp_dim, base_emb_dim), reduction_dim = (-2,), output_dim = (-1,)decoder.moe_layers.self_attention.out.kernel, (base_num_query_heads, num_layer, v_head_dim, base_emb_dim), reduction = (0, -2), outputdim = (-1,)To generalize the reshaping
Example use
opt_type=muon, optionallymuon_beta=0.95,muon_weight_decay=0.1,muon_consistent_rms=0.2Pretrain
Train compile
Tests
unit test for reshape:
python -m MaxText.muon_dimension_numberend-to-end test: b/437908829
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.