Hi, thanks for releasing the Muon implementation and examples.
I noticed that in the provided attention example, the optimizer applies Muon to each full Q/K/V projection matrix with all heads concatenated along the output dimension (e.g., Wq of shape [d_model, n_heads * d_head]) rather than performing orthogonalization per attention head (i.e., per [d_model, d_head] block).
I would like to ask about the rationale for this design choice.
From my understanding, Muon is defined as a matrix-level optimizer, so treating each projection matrix as a single 2-D parameter is consistent with the theory. However, concatenating all heads means the heads are no longer decoupled: each head's update depends on the gradients of every other head, and the orthogonalization step implicitly acts on cross-head correlations instead of treating each [d_model, d_head] block in isolation.
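For concreteness, here is a rough sketch of the two options as I understand them. The `newton_schulz` helper, the shapes, and `grad_Wq` are placeholders I made up for illustration (a simplified float32 quintic Newton–Schulz with the commonly cited coefficients), not necessarily the exact routine used in this repo:

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize G via a quintic Newton-Schulz iteration
    (simplified float32 sketch, not the repo's exact implementation)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)              # normalize so the iteration converges
    transposed = X.size(0) > X.size(1)
    if transposed:                        # iterate on the "wide" orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

d_model, n_heads, d_head = 512, 8, 64    # hypothetical shapes for illustration
grad_Wq = torch.randn(d_model, n_heads * d_head)

# (a) what the example appears to do: orthogonalize the full concatenated matrix
update_concat = newton_schulz(grad_Wq)

# (b) the alternative I'm asking about: orthogonalize each head's
#     [d_model, d_head] block independently, then re-concatenate
head_blocks = grad_Wq.view(d_model, n_heads, d_head).unbind(dim=1)
update_per_head = torch.cat([newton_schulz(g) for g in head_blocks], dim=1)
```

In (a) the Newton–Schulz step mixes gradient information across all heads of the projection; in (b) each head is orthogonalized on its own.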
Could you clarify the reasoning behind optimizing the concatenated matrices instead of head-wise blocks? Is this primarily a mathematical consideration, a stability constraint, or an engineering/performance decision?
Thank you!