Cautious Muon + SP4096 + Depth Recurrence — val_bpb 1.1604 (non-record) #1381
# Cautious Muon + SP4096 + Depth Recurrence + Parallel Residuals

**val_bpb = 1.1604** (3-seed mean, std = 0.0033)

This is a non-record submission exploring the effect of Cautious Muon (arXiv:2411.16085) on the PR #1334 architecture stack.
## Results

| Seed | val_bpb | val_loss | Artifact Size |
|------|---------|----------|---------------|
| 42 | 1.1568 | 2.6619 | 15,179,504 B |
| 314 | 1.1611 | 2.6717 | 15,173,470 B |
| 999 | 1.1634 | 2.6770 | 15,159,223 B |
| **Mean** | **1.1604** | **2.6702** | **15,170,732 B** |
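As a sanity check, the loss and bpb columns are mutually consistent. A minimal sketch, assuming the benchmark uses the standard conversion `val_bpb = val_loss / (ln 2 * bytes_per_token)`:

```python
import math

# 3-seed means from the table above
val_loss = 2.67020395  # nats per token
val_bpb = 1.16043988   # bits per byte

# Implied average bytes per SP4096 token on the validation set,
# assuming val_bpb = val_loss / (ln(2) * bytes_per_token).
bytes_per_token = val_loss / (math.log(2) * val_bpb)
print(round(bytes_per_token, 2))  # ~3.32 bytes per token
```

So each SP4096 token covers roughly 3.3 bytes of text on average, which is plausible for a 4096-entry BPE vocabulary.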
## Key Technique: Cautious Muon (arXiv:2411.16085)

The primary modification applies the Cautious optimizer principle to the Muon optimizer. The update pipeline is:

1. Compute the momentum-corrected (Nesterov) gradient
2. **MuonEq-R row normalization**: normalize rows before Newton-Schulz
3. **Newton-Schulz orthogonalization** (5 steps): project the update onto the orthogonal manifold
4. **Cautious masking**: apply the update only where the orthogonalized direction agrees with the raw gradient sign:

```python
caution_mask = (g * raw_grad > 0).to(g.dtype)
g = g * caution_mask / caution_mask.mean().clamp_min(1e-3)
```

This filters out "stale" momentum directions that disagree with the current gradient. The Cautious Optimizers paper reports up to ~1.47x faster effective convergence per step; the masking adds zero parameter overhead and has no impact on artifact size.
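The four steps above can be sketched end to end. This is an illustrative NumPy sketch, not the submission's actual `Muon.step()`: the function names, the Frobenius normalization, and the quintic Newton-Schulz coefficients (3.4445, -4.7750, 2.0315) follow the commonly published Muon reference and may differ from the decompressed `train_gpt.py`:

```python
import numpy as np

def newton_schulz(G, steps=5):
    # Quintic Newton-Schulz iteration; coefficients from the public Muon reference.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)  # Frobenius normalization keeps singular values <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T  # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def cautious_muon_update(momentum_buf, raw_grad, beta=0.99):
    # 1. Nesterov momentum correction
    momentum_buf = beta * momentum_buf + raw_grad
    g = raw_grad + beta * momentum_buf
    # 2. Row normalization before orthogonalization (MuonEq-R)
    g = g / (np.linalg.norm(g, axis=1, keepdims=True) + 1e-7)
    # 3. Newton-Schulz orthogonalization (5 steps)
    g = newton_schulz(g)
    # 4. Cautious masking: zero components disagreeing with the raw gradient sign,
    #    rescaled by the surviving fraction (clamped to avoid division blow-up)
    mask = (g * raw_grad > 0).astype(g.dtype)
    g = g * mask / max(mask.mean(), 1e-3)
    return g, momentum_buf
```

After masking, every surviving component agrees in sign with the raw gradient, so the orthogonalized update never moves directly against the current gradient.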
## Full Architecture Stack

Built on PR #1334 (@aryanbhosale) with:
- **SP4096 BPE tokenizer** (from PR #1218, @clarkkev)
- **Depth recurrence** on layers 4 and 5 (13 virtual layers from 11 physical, activated at step 3000)
- **Parallel residuals** from layer 7 (separate attn/MLP lanes with a learnable merge)
- **MuonEq-R** row normalization before Newton-Schulz orthogonalization (arXiv:2603.28254)
- **Cautious masking** applied after Newton-Schulz orthogonalization
- **QK-Gain 5.0** per-head query-key scaling
- **EMA 0.997** weight averaging
- **Full GPTQ INT6** quantization with selective ±1 pruning
- **Brotli compression**
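The depth-recurrence wiring can be illustrated with a small sketch. `virtual_layer_schedule` is a hypothetical helper (not from PR #1334) that reproduces the virtual-layer order reported in the training logs:

```python
def virtual_layer_schedule(num_layers, recur_layers):
    """Expand physical layers into a virtual schedule by replaying the
    recurrent block once, immediately after its first pass."""
    schedule = []
    for i in range(num_layers):
        schedule.append(i)
        if i == max(recur_layers):
            schedule.extend(recur_layers)  # second pass through the recurrent block
    return schedule

print(virtual_layer_schedule(11, [4, 5]))
# [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]  (13 virtual layers from 11 physical)
```

Because the replayed layers reuse the same weights, the extra depth costs compute but no parameters, which matters under the 16 MB artifact cap.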
## Non-matrix parameters

Token embeddings, scalar parameters, and the head use standard `torch.optim.AdamW`. Cautious masking is applied only inside Muon, i.e. only to matrix parameters.
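A hypothetical sketch of the split; the parameter names and the selection rule are illustrative, and the actual grouping logic in `train_gpt.py` may differ:

```python
def partition_params(named_shapes):
    """Route 2-D transformer matrices to Muon; embeddings, the head,
    and scalar parameters to AdamW (illustrative rule only)."""
    muon, adamw = [], []
    for name, shape in named_shapes.items():
        is_matrix = len(shape) == 2
        is_embed_or_head = any(k in name for k in ("embed", "head"))
        (muon if is_matrix and not is_embed_or_head else adamw).append(name)
    return muon, adamw

params = {
    "embed.weight": (4096, 512),            # token embeddings -> AdamW
    "blocks.0.attn.wq.weight": (512, 512),  # matrix -> Muon (cautious masking applies)
    "blocks.0.mlp.w1.weight": (2048, 512),  # matrix -> Muon
    "blocks.0.skip_gate": (1,),             # scalar -> AdamW
    "lm_head.weight": (4096, 512),          # head -> AdamW
}
muon, adamw = partition_params(params)
```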
## Compliance

- Track A fixed predictor: no TTT, no SLOT, no eval-time adaptation
- All predictions are causal and normalized via softmax (`F.cross_entropy`)
- Artifact under the 16 MB limit (max 15,179,504 bytes)
- Training completes within the 600 s wallclock cap on 8x H100 SXM
## Reproduction

```bash
cd /workspace/parameter-golf
# Download SP4096 data
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp4096
# Run
SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py
```
## Decompressed Source

`train_gpt.py` is a self-extracting compressed script (`exec(lzma.decompress(base64.b85decode(...)))`). The decompressed source is ~1991 lines of Python. The only modification to the PR #1334 base is the two-line Cautious Muon mask inside `Muon.step()` (lines ~869-872 of the decompressed source).
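The self-extracting pattern can be reproduced in miniature. `pack` is a hypothetical helper showing the same `exec(lzma.decompress(base64.b85decode(...)))` round-trip; it is not the tooling used to build the actual `train_gpt.py`:

```python
import base64
import lzma

def pack(source: str) -> str:
    """Compress Python source and wrap it in a self-extracting one-liner."""
    payload = base64.b85encode(lzma.compress(source.encode())).decode("ascii")
    return f"exec(lzma.decompress(base64.b85decode({payload!r})))"

# Round-trip: exec'ing the stub runs the original source.
stub = pack("RESULT = 6 * 7")
namespace = {"lzma": lzma, "base64": base64}
exec(stub, namespace)
print(namespace["RESULT"])  # 42
```

LZMA plus Base85 keeps the source payload compact while staying printable ASCII, so the whole script remains a single valid `.py` file that counts toward the code-size budget.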
## Credits

- PR #1334 (@aryanbhosale): base architecture (SP4096, depth recurrence, parallel residuals, MuonEq-R)
- PR #1218 (@clarkkev): SP4096 tokenizer
- Liang et al. (arXiv:2411.16085): Cautious Optimizers
Submission metadata:

```json
{
  "author": "X-Abhishek-X",
  "github_id": "X-Abhishek-X",
  "name": "Cautious Muon + SP4096 + Depth Recurrence + Parallel Residuals",
  "blurb": "Non-record submission. Applies Cautious Muon (arXiv:2411.16085) to the Muon optimizer — masks Newton-Schulz updates where the orthogonalized direction disagrees with the raw gradient sign, providing ~1.47x effective convergence per step with zero parameter overhead. Built on PR #1334 (aryanbhosale) base with SP4096 vocabulary, depth recurrence (layers 4,5), parallel residuals (from layer 7), MuonEq-R, QK-Gain 5.0, and full GPTQ INT6 + Brotli compression. Mean val_bpb = 1.1604 (3 seeds, std = 0.0033).",
  "date": "2026-04-05T00:00:00Z",
  "track": "10min_16mb",
  "val_loss": 2.67020395,
  "val_bpb": 1.16043988,
  "val_loss_std": 0.00764948,
  "val_bpb_std": 0.00332438,
  "seeds": [42, 314, 999],
  "seed_results": {
    "42": {
      "val_loss": 2.66190192,
      "val_bpb": 1.15683191,
      "artifact_bytes": 15179504
    },
    "314": {
      "val_loss": 2.67174312,
      "val_bpb": 1.16110878,
      "artifact_bytes": 15173470
    },
    "999": {
      "val_loss": 2.67696681,
      "val_bpb": 1.16337894,
      "artifact_bytes": 15159223
    }
  },
  "artifact_bytes_mean": 15170732,
  "artifact_bytes_max": 15179504,
  "bytes_total": 15179504,
  "bytes_code": 24659,
  "hardware": "8x H100 SXM (RunPod On-Demand)",
  "pytorch_version": "2.9.1",
  "cuda_version": "12.8",
  "technique_summary": "Cautious Muon optimizer (arXiv:2411.16085), SP4096 BPE tokenizer, depth recurrence layers 4-5 (start step 3000), parallel residuals from layer 7, MuonEq-R row normalization, QK-Gain 5.0, EMA 0.997, full GPTQ INT6 quantization with selective pruning, Brotli compression",
  "comparison_baseline_pr": 1334
}
```
Training log (seed 314):

```
W0405 12:33:07.793000 58424 torch/distributed/run.py:803]
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] *****************************************
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
cautious_muon: True
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp4096
distributed: True
ema_decay: 0.997
embed_lr: 0.6
embed_wd: 0.09
embedding_dim: 512
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_enabled: True
gptq_reserve_seconds: 10.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/523d357e-1519-45f9-bc20-69bbfb520c1b.txt
logit_softcap: 30.0
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_wd: 0.09
num_heads: 8
num_kv_heads: 4
num_layers: 11
parallel_start_layer: 7
qk_gain_init: 5.0
quantized_model_path: final_model.int6.ptz
rank: 0
recur_layers: 4,5
recur_start_step: 3000
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 523d357e-1519-45f9-bc20-69bbfb520c1b
scalar_lr: 0.02
seed: 314
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_batch_seqs: 32
ttt_chunk_tokens: 32768
ttt_enabled: False
ttt_epochs: 3
ttt_freeze_blocks: 0
ttt_grad_clip: 1.0
ttt_lr: 0.002
ttt_momentum: 0.9
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin
val_loss_every: 4000
ve_dim: 128
ve_enabled: True
ve_layers: 9,10
vocab_size: 4096
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 80
val_tokens: 45508608
model_params:34401372
gptq:reserving 10s, effective=590000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
0/20000 val_loss: 8.3172 val_bpb: 3.6146
1/20000 train_loss: 8.3192 train_time: 0.0m tok/s: 8417806
2/20000 train_loss: 11.4353 train_time: 0.0m tok/s: 8294705
3/20000 train_loss: 8.8870 train_time: 0.0m tok/s: 8200165
4/20000 train_loss: 8.0296 train_time: 0.0m tok/s: 8159612
5/20000 train_loss: 8.6996 train_time: 0.0m tok/s: 8141617
500/20000 train_loss: 3.0633 train_time: 0.8m tok/s: 7939661
1000/20000 train_loss: 2.9869 train_time: 1.7m tok/s: 7911553
1500/20000 train_loss: 2.9710 train_time: 2.5m tok/s: 7908791
2000/20000 train_loss: 2.6991 train_time: 3.3m tok/s: 7909908
2500/20000 train_loss: 2.7380 train_time: 4.1m tok/s: 7911617
3000/20000 train_loss: 2.7805 train_time: 5.0m tok/s: 7913690
recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
3500/20000 train_loss: 2.6930 train_time: 6.1m tok/s: 7516407
4000/20000 train_loss: 2.6237 train_time: 7.1m tok/s: 7424228
4000/20000 val_loss: 2.6465 val_bpb: 1.1501
4500/20000 train_loss: 2.5745 train_time: 8.0m tok/s: 7355829
5000/20000 train_loss: 2.5200 train_time: 9.0m tok/s: 7301867
5449/20000 val_loss: 2.5373 val_bpb: 1.1027
stopping_early: wallclock_cap train_time: 590030ms step: 5449/20000
peak memory allocated: 30120 MiB reserved: 30154 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.53590710 val_bpb:1.10207601 eval_time:2003ms
Serialized model: 132406149 bytes
Code size: 24659 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 66 Hessians in 9.7s
GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
selective_prune: unpruned=16.59MB target=16.0MB
selective_prune: pruning 4714416/8989294 lowest-error ±1 values (excess=589302B)
Serialized model int6+brotli: 15148811 bytes
Total submission size int6+brotli: 15173470 bytes
final_int6_roundtrip val_loss:2.71751254 val_bpb:1.18099965 eval_time:8188ms
final_int6_sliding_window val_loss:2.67174312 val_bpb:1.16110878 eval_time:76629ms
```
Training log (seed 42):

```
W0405 12:16:26.757000 47726 torch/distributed/run.py:803]
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] *****************************************
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
cautious_muon: True
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp4096
distributed: True
ema_decay: 0.997
embed_lr: 0.6
embed_wd: 0.09
embedding_dim: 512
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_enabled: True
gptq_reserve_seconds: 10.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/6e35fe72-b1cf-49e5-95dc-2b7d967c8075.txt
logit_softcap: 30.0
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_wd: 0.09
num_heads: 8
num_kv_heads: 4
num_layers: 11
parallel_start_layer: 7
qk_gain_init: 5.0
quantized_model_path: final_model.int6.ptz
rank: 0
recur_layers: 4,5
recur_start_step: 3000
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 6e35fe72-b1cf-49e5-95dc-2b7d967c8075
scalar_lr: 0.02
seed: 42
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_batch_seqs: 32
ttt_chunk_tokens: 32768
ttt_enabled: False
ttt_epochs: 3
ttt_freeze_blocks: 0
ttt_grad_clip: 1.0
ttt_lr: 0.002
ttt_momentum: 0.9
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin
val_loss_every: 4000
ve_dim: 128
ve_enabled: True
ve_layers: 9,10
vocab_size: 4096
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 80
val_tokens: 45508608
model_params:34401372
gptq:reserving 10s, effective=590000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
0/20000 val_loss: 8.3187 val_bpb: 3.6152
1/20000 train_loss: 8.3201 train_time: 0.0m tok/s: 8409702
2/20000 train_loss: 11.3749 train_time: 0.0m tok/s: 8301223
3/20000 train_loss: 8.9270 train_time: 0.0m tok/s: 8232435
4/20000 train_loss: 8.0370 train_time: 0.0m tok/s: 8197540
5/20000 train_loss: 8.6257 train_time: 0.0m tok/s: 8158832
500/20000 train_loss: 3.0617 train_time: 0.8m tok/s: 7952184
1000/20000 train_loss: 2.9819 train_time: 1.7m tok/s: 7929746
1500/20000 train_loss: 2.9701 train_time: 2.5m tok/s: 7920772
2000/20000 train_loss: 2.6938 train_time: 3.3m tok/s: 7916245
2500/20000 train_loss: 2.7396 train_time: 4.1m tok/s: 7918140
3000/20000 train_loss: 2.7789 train_time: 5.0m tok/s: 7919917
recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
3500/20000 train_loss: 2.6938 train_time: 6.1m tok/s: 7516573
4000/20000 train_loss: 2.6257 train_time: 7.1m tok/s: 7425494
4000/20000 val_loss: 2.6472 val_bpb: 1.1505
4500/20000 train_loss: 2.5745 train_time: 8.0m tok/s: 7356557
5000/20000 train_loss: 2.5209 train_time: 9.0m tok/s: 7302902
5450/20000 val_loss: 2.5381 val_bpb: 1.1030
stopping_early: wallclock_cap train_time: 590058ms step: 5450/20000
peak memory allocated: 30120 MiB reserved: 30154 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.53669680 val_bpb:1.10241920 eval_time:2002ms
Serialized model: 132406149 bytes
Code size: 24659 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 66 Hessians in 9.8s
GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
selective_prune: unpruned=16.59MB target=16.0MB
selective_prune: pruning 4683904/8992183 lowest-error ±1 values (excess=585488B)
Serialized model int6+brotli: 15154845 bytes
Total submission size int6+brotli: 15179504 bytes
final_int6_roundtrip val_loss:2.70991443 val_bpb:1.17769759 eval_time:23408ms
final_int6_sliding_window val_loss:2.66190192 val_bpb:1.15683191 eval_time:99175ms
```
Review comment: This submission is labeled as a "Record" (folder/PR title), but the reported mean `val_bpb = 1.1604` is substantially worse than the current 10min_16mb leaderboard entries (e.g. 1.1228 in the repo README). Consider renaming the PR/folder/README `name` to avoid implying it's a new SOTA record if it's intended as a non-record/ablation submission.