From f47b26c5714ebd06534c0682628769dde83a155c Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 24 Apr 2026 11:35:32 -0400
Subject: [PATCH 1/4] Update leaderboard with recent record submissions

---
 README.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/README.md b/README.md
index 0393a3b7f2..0aaa2e5a4b 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,8 @@ Happy training!
 | Run | Score | Author | Summary | Date | Info |
 |-----|------:|--------|---------|------|------|
+| Scylla Tokenizer + Full GPTQ + XSA-all + FA3 | 0.9485 | icryo | On PR #1184: Scylla TokenMonster tokenizer from PR #1143 + the PR #1060 Full GPTQ/XSA-all/coprime-loader stack + FA3 | 2026-03-31 | [info](records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/README.md) |
+| Scylla Tokenizer + Legal Score-First TTT | 1.0806 | simon-marcus | On PR #1143: TokenMonster-derived Scylla tokenizer, retokenized FineWeb, metadata byte accounting, and legal score-first TTT | 2026-03-30 | [info](https://github.com/openai/parameter-golf/pull/1143) |
 | SP8192 + 3-Layer Recurrence + Parallel Residuals + Legal TTT | 1.0810 | bigbag | On PR #1493: 3-layer recurrence, parallel residuals, QK-Gain 5.25, and legal score-first TTT on the PR #1394 stack | 2026-04-09 | [info](records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/README.md) |
 | SP8192 + Parallel Residuals + Score-First TTT | 1.0822 | aryanbhosale | On PR #1477: parallel residuals on the PR #1413 SP8192 + legal score-first TTT stack | 2026-04-08 | [info](records/track_10min_16mb/2026-04-08_SP8192_ParallelResid_ScoreFirstTTT/README.md) |
 | SP8192 + QK-Gain 5 + Legal Score-First TTT | 1.0828 | dexhunter | On PR #1413: QK-Gain 5.0 + legal score-first TTT on the PR #1394 SP8192 stack | 2026-04-06 | [info](records/track_10min_16mb/2026-04-06_SP8192_QK5_LegalTTT_1.0828/README.md) |
@@ -39,7 +41,11 @@ Happy training!
 | MuonEq-R + Depth Recurrence + WD=0.090 + All-Int6 GPTQ | 1.0912 | dexhunter | On PR #1285: MuonEq-R + layers 4-5 recurrence + higher weight decay + all-int6 GPTQ | 2026-04-03 | [info](records/track_10min_16mb/2026-04-03_MuonEqR_DepthRecurrence_WD090_AllInt6/README.md) |
 | 4096-Vocab + Larger Model + High WD + Simplifications | 1.0979 | Kevin Clark | On PR #1218: SP4096 + 4x MLP + high weight decay, with TTT, hash embeddings, SmearGate, and value residuals removed | 2026-04-01 | [info](records/track_10min_16mb/2026-04-01_Vocab4096_MLPMult4_WD085/README.md) |
 | Parallel Residuals + Mini Depth Recurrence | 1.1063 | Marko Sisovic | On PR #1204: mini recurrence on layers 4-5 + parallel attention/MLP residual lanes + AR self-generated GPTQ calibration | 2026-03-31 | [info](records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/README.md) |
+| Rascal | 1.1099 | newjordan | On PR #1120: XSA-all + Parallel Muon + coprime loader + Bigram2048/RoPE16 + SWA/late QAT without GPTQ | 2026-03-30 | [info](https://github.com/openai/parameter-golf/pull/1120) |
+| Coprime-Stride Loader + Full GPTQ + XSA-all | 1.1122 | dexhunter | On PR #1060: coprime multi-shard loader + Full Hessian GPTQ + XSA on all layers + BigramHash(2816x112) | 2026-03-29 | [info](records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_BigramHash2816/README.md) |
 | 11L AR Self-Gen GPTQ + XSA | 1.1147 | abaybektursun | On PR #1019: Self-Generated GPTQ Calibration Data + all-layer XSA on the PR #549 stack | 2026-03-25 | [info](records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md) |
+| 11L Muon TTT + Entropy-Adaptive Epochs | 1.1179 | aamodbhatt | On PR #1148: Muon-style Newton-Schulz TTT updates + entropy-adaptive 2/3/4 epoch selection on the PR #549 stack | 2026-03-28 | [info](records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/README.md) |
+| LeakyReLU(0.75)^2 + Legal TTT + Parallel Muon | 1.1185 | michaelwinczuk | On PR #1031: LeakyReLU negative_slope 0.75 with tuned matrix LR and warmdown on the PR #549 stack | 2026-03-27 | [info](https://github.com/openai/parameter-golf/pull/1031) |
 | LeakyReLU² + Legal Score-First TTT + Parallel Muon | 1.1194 | abaybektursun | On PR #549: LeakyReLU(0.5)^2 + TTT + Parallel Muon on the PR #414 stack | 2026-03-23 | [info](records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/README.md) |
 | 11L EMA + GPTQ-lite + warmdown3500 | 1.1228 | signalrush | On PR #374: GPTQ-lite clip search + EMA, plus warmdown3500 and QAT@0.15 | 2026-03-22 | [info](records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/README.md) |
 | 11L Partial RoPE + LN Scale + EMA + XSA4 | 1.1248 | jfprincz | On PR #287: Partial RoPE (16/64) + layerwise LN scale | 2026-03-21 | [info](records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/README.md) |

From f13ae15415e9e66a464e0293b21121d25ce19102 Mon Sep 17 00:00:00 2001
From: Alex
Date: Sun, 26 Apr 2026 00:40:33 -0400
Subject: [PATCH 2/4] Keep only valid recent leaderboard rows

---
 README.md | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/README.md b/README.md
index 0aaa2e5a4b..1445065db7 100644
--- a/README.md
+++ b/README.md
@@ -30,8 +30,6 @@ Happy training!
 | Run | Score | Author | Summary | Date | Info |
 |-----|------:|--------|---------|------|------|
-| Scylla Tokenizer + Full GPTQ + XSA-all + FA3 | 0.9485 | icryo | On PR #1184: Scylla TokenMonster tokenizer from PR #1143 + the PR #1060 Full GPTQ/XSA-all/coprime-loader stack + FA3 | 2026-03-31 | [info](records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/README.md) |
-| Scylla Tokenizer + Legal Score-First TTT | 1.0806 | simon-marcus | On PR #1143: TokenMonster-derived Scylla tokenizer, retokenized FineWeb, metadata byte accounting, and legal score-first TTT | 2026-03-30 | [info](https://github.com/openai/parameter-golf/pull/1143) |
 | SP8192 + 3-Layer Recurrence + Parallel Residuals + Legal TTT | 1.0810 | bigbag | On PR #1493: 3-layer recurrence, parallel residuals, QK-Gain 5.25, and legal score-first TTT on the PR #1394 stack | 2026-04-09 | [info](records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/README.md) |
 | SP8192 + Parallel Residuals + Score-First TTT | 1.0822 | aryanbhosale | On PR #1477: parallel residuals on the PR #1413 SP8192 + legal score-first TTT stack | 2026-04-08 | [info](records/track_10min_16mb/2026-04-08_SP8192_ParallelResid_ScoreFirstTTT/README.md) |
 | SP8192 + QK-Gain 5 + Legal Score-First TTT | 1.0828 | dexhunter | On PR #1413: QK-Gain 5.0 + legal score-first TTT on the PR #1394 SP8192 stack | 2026-04-06 | [info](records/track_10min_16mb/2026-04-06_SP8192_QK5_LegalTTT_1.0828/README.md) |
@@ -44,8 +42,6 @@ Happy training!
 | Rascal | 1.1099 | newjordan | On PR #1120: XSA-all + Parallel Muon + coprime loader + Bigram2048/RoPE16 + SWA/late QAT without GPTQ | 2026-03-30 | [info](https://github.com/openai/parameter-golf/pull/1120) |
 | Coprime-Stride Loader + Full GPTQ + XSA-all | 1.1122 | dexhunter | On PR #1060: coprime multi-shard loader + Full Hessian GPTQ + XSA on all layers + BigramHash(2816x112) | 2026-03-29 | [info](records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_BigramHash2816/README.md) |
 | 11L AR Self-Gen GPTQ + XSA | 1.1147 | abaybektursun | On PR #1019: Self-Generated GPTQ Calibration Data + all-layer XSA on the PR #549 stack | 2026-03-25 | [info](records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md) |
-| 11L Muon TTT + Entropy-Adaptive Epochs | 1.1179 | aamodbhatt | On PR #1148: Muon-style Newton-Schulz TTT updates + entropy-adaptive 2/3/4 epoch selection on the PR #549 stack | 2026-03-28 | [info](records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/README.md) |
-| LeakyReLU(0.75)^2 + Legal TTT + Parallel Muon | 1.1185 | michaelwinczuk | On PR #1031: LeakyReLU negative_slope 0.75 with tuned matrix LR and warmdown on the PR #549 stack | 2026-03-27 | [info](https://github.com/openai/parameter-golf/pull/1031) |
 | LeakyReLU² + Legal Score-First TTT + Parallel Muon | 1.1194 | abaybektursun | On PR #549: LeakyReLU(0.5)^2 + TTT + Parallel Muon on the PR #414 stack | 2026-03-23 | [info](records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/README.md) |
 | 11L EMA + GPTQ-lite + warmdown3500 | 1.1228 | signalrush | On PR #374: GPTQ-lite clip search + EMA, plus warmdown3500 and QAT@0.15 | 2026-03-22 | [info](records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/README.md) |
 | 11L Partial RoPE + LN Scale + EMA + XSA4 | 1.1248 | jfprincz | On PR #287: Partial RoPE (16/64) + layerwise LN scale | 2026-03-21 | [info](records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/README.md) |

From 6c808fcf1f2cac2d887a907283ac667d22715a41 Mon Sep 17 00:00:00 2001
From: Alex
Date: Sun, 26 Apr 2026 00:41:21 -0400
Subject: [PATCH 3/4] Remove invalid Scylla record

---
 .../README.md          |   95 -
 .../candidate.meta.npz |  Bin 1770 -> 0 bytes
 .../candidate.vocab    |  Bin 25678 -> 0 bytes
 .../submission.json    |    9 -
 .../train_gpt.py       | 2173 --------------
 .../train_seed1337.log | 2505 -----------------
 .../train_seed2025.log | 2308 ---------------
 .../train_seed42.log   | 2308 ---------------
 8 files changed, 9398 deletions(-)
 delete mode 100644 records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/README.md
 delete mode 100644 records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/candidate.meta.npz
 delete mode 100644 records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/candidate.vocab
 delete mode 100644 records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/submission.json
 delete mode 100644 records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_gpt.py
 delete mode 100644 records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed1337.log
 delete mode 100644 records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed2025.log
 delete mode 100644 records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed42.log

diff --git a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/README.md b/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/README.md
deleted file mode 100644
index ac177659af..0000000000
--- a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/README.md
+++ /dev/null
@@ -1,95 +0,0 @@
-# Record: Scylla Tokenizer + Full GPTQ + XSA-all + FA3
-
-**val_bpb: 0.9485** (3-seed mean, std 0.0008) | **~15.6 MB** | 8×H100 SXM | No TTT
-
-## Results (8×H100 SXM)
-
-| Seed | Steps | ms/step | Sliding BPB (s64) |
-|------|-------|---------|--------------------|
-| 1337 | 6,716 | 87.9 | 0.9491 |
-| 42 | — | ~88 | **0.9476** |
-| 2025 | — | ~88 | 0.9489 |
-| **Mean ± Std** | | | **0.9485 ± 0.0008** |
-
-## vs Competition
-
-| Submission | BPB | Improvement |
-|-----------|-----|-------------|
-| PR #1143 (Scylla, old stack) | 1.0806 | **-0.1321 (-12.2%)** |
-| PR #1089 (Turbo-Muon) | 1.1086 | -0.1601 |
-| Merged SOTA (PR #549) | 1.1194 | -0.1709 |
-
-## What's New
-
-This submission combines the Scylla tokenizer (PR #1143) with the modern training stack (PR #1060), achieving a result far better than either alone.
-
-### Tokenizer: Scylla (998 tokens)
-- TokenMonster-derived vocabulary, pruned from `english-1024-clean-v1`
-- Created by @simon-marcus through iterative autoresearch (PR #1143)
-- 998 active tokens (vs 1024 for SentencePiece)
-- Better byte-per-token efficiency via ungreedy multi-branch tokenization
-- Retokenized FineWeb: 194 train shards (~19.4B tokens) + 1 val shard
-
-### Training Stack (PR #1060 base)
-- **Full Hessian GPTQ** — Cholesky error compensation, 64-batch calibration in 6.7s
-- **XSA on all 11 layers** — exclusive self-attention everywhere
-- **Coprime-stride multi-shard loader** — diverse batches across 194 shards
-- **FlashAttention 3** — Hopper native kernels (pre-built wheel)
-- **Parallel Muon** + Parameter Banking — 3-phase overlapped optimizer
-
-### Why It's Better Than #1143 Alone
-PR #1143 used the old SOTA stack (PR #549 base), which lacks:
-- Full GPTQ (used GPTQ-lite → worse quantization)
-- XSA on all layers (used last 4 → less cross-position mixing)
-- Coprime data loader (sequential loading → less batch diversity)
-- As much training data (79 shards vs our 194)
-
-### No TTT Needed
-TTT was tested and found neutral (0.9491 with TTT, 0.9491 without). Full GPTQ eliminates the need for test-time adaptation.
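The "Sliding BPB (s64)" numbers in the results table come from a sliding-window bits-per-byte evaluation: every token is scored exactly once, conditioned on up to a window of left context, and the summed negative log-likelihood is normalized by the byte count rather than the token count. A minimal sketch of that metric, with illustrative names and toy `window`/`stride` values (the record itself uses stride 64; `logprob_fn` stands in for a causal LM returning next-token log-probs):

```python
import math
import numpy as np

def sliding_bpb(logprob_fn, ids, token_bytes, window=8, stride=4):
    """Sliding-window bits-per-byte: each token scored once, with context.

    logprob_fn(chunk) -> (len(chunk), vocab) log-probs; row j predicts chunk[j+1].
    token_bytes[i] = decoded byte count of token i (the BPB denominator).
    """
    N = len(ids)
    total_nll, total_bytes = 0.0, 0
    pos = 1  # index of the first scored target token
    while pos < N:
        lo = max(0, pos + stride - window)  # window start
        hi = min(pos + stride, N)           # one past the last target
        logp = logprob_fn(ids[lo:hi])
        for t in range(pos, hi):
            j = t - lo                      # row j-1 predicts token at j
            total_nll += -logp[j - 1, ids[t]]
            total_bytes += int(token_bytes[ids[t]])
        pos = hi
    return total_nll / math.log(2) / total_bytes  # nats -> bits, per byte
```

As a sanity check, a uniform model over a 16-token vocabulary with one byte per token gives exactly log2(16) = 4 bits per byte, independent of window and stride.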
-
-## Architecture
-
-- 11L, 512d, 8H/4KV (GQA), MLP 3× LeakyReLU(0.5)²
-- XSA on all 11 layers, BigramHash(2816×112), SmearGate
-- Partial RoPE (16d), LN Scale 1/√(l+1)
-- Shared ValueEmbedding (dim=128, layers 9-10)
-- EMA (decay=0.997) + Tight SWA (every 50 steps)
-- Full Hessian GPTQ int6 + LZMA compression
-
-## Timing
-
-| Phase | Time |
-|-------|------|
-| Training (6,716 steps @ 88ms) | 591s |
-| GPTQ calibration | 6.7s |
-| Sliding window eval (stride=64) | 92s |
-
-## Reproduction
-
-```bash
-# Install
-pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128
-pip install "https://download.pytorch.org/whl/cu128/flash_attn_3-3.0.0-cp39-abi3-manylinux_2_28_x86_64.whl"
-pip install sentencepiece huggingface-hub datasets numpy tqdm tokenmonster
-
-# Retokenize FineWeb with Scylla vocab (once, ~90 min)
-python3 retokenize.py --vocab candidate.vocab --output-dir data/datasets/fineweb10B_scylla
-
-# Train
-SEED=1337 DATA_PATH=./data/datasets/fineweb10B_scylla \
-TOKENIZER_PATH=./candidate.vocab TOKENIZER_META_PATH=./candidate.meta.npz \
-VOCAB_SIZE=998 XSA_LAST_N=11 USE_GPTQ=1 GPTQ_RESERVE_MS=9000 TTT_ENABLED=0 \
-BIGRAM_VOCAB_SIZE=2816 BIGRAM_DIM=112 \
-torchrun --standalone --nproc_per_node=8 train_gpt.py
-```
-
-## Included Files
-- `train_gpt.py` — Training script (PR #1060 + tokenizer metadata loading)
-- `candidate.vocab` — Scylla tokenizer (998 tokens)
-- `candidate.meta.npz` — Per-token byte accounting metadata
-- `train_seed{1337,42,2025}.log` — Training logs for all 3 seeds
-
-## Credits
-- **Scylla tokenizer**: @simon-marcus (PR #1143)
-- **Training stack**: @resouer (PR #1060), @abaybektursun (PR #549)
-- **Retokenization pipeline**: Built for this submission
diff --git a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/candidate.meta.npz b/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/candidate.meta.npz
deleted file mode 100644
index
580a1e39e538effabf4bebdc78998ad14fdbb388..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1770 zcmWIWW@gc4U|`??Vnv1}8|Sb34+TODA`EHyMY)M3@nxw+#hLkedU*wvj0^${9YD1p z;0B=>X1@`CB`t9BM8L6tweu!L&52o(w{$_?HnVTf+2HYL*q(RxS}nJvTO>^#Ld@DZHATh3t}&{Or`c%&OF)`0UKQ6qrjk zKwS`r*CoPP8B5TeGN}NIQ>HMi6Jld)t5G<|_9#T*9P^`?6%rk>JV#Umb}A{bp}BCs zdSD;Sg`&W)D$Xx0N=}W>%}+_qiO);SO@%p9gbCuh^H4`Fov@MXkbywk<5$5Y-A-F| zUnu1@&3<(0;L0OrTaVtFvP8_ahAaLC_n*cKR}Z*XmH)3VKl^=#oZoX!=^L}fFFl!+ z&$Q2Xvc<#11Kvxkb#M8+bh>pa&crR~UdW=!BIlnae&GCPak;~Ano!$>mSg2DyVc&c zC`COf@8GV;S*OhMheBpPZN!Ukr>nSa85nk2~HtBR=)4OFzStfS!Enq>n8@c1BWSacX>0Wl3r= z%o#qQ6b}Rsz$63CM2wMg&=c{rX<}C3gq*?*OvtYsI2KIHJd@(+#-_H)!A+xcf#gDy zBmF)hn>%++hzK`aw1TZwMtG*8-4z8DwV4kKy``fkoK0={tnkgyaKi-NnJ*W#&&-*v zQOKxMlJL=ap6?n54#ktrQ(i1xH}lD^7D?4N1*?+}b90`Y%%7;X`9TS{wZx%}D94i9 z76ndaIy#T8I-Pl@B3Wcmkk~pW;d{|xMSh7-Oh;RP`qCgr38W?qeF2Gvspc?_om7x3v@&X@BKag%I_MHs5nUotabPjti^Q3o^>YZ3|YW@Q6f%D@hU2Z2E(%nITG E0M&j?J^%m! diff --git a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/candidate.vocab b/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/candidate.vocab deleted file mode 100644 index a2b7009ad33ecf35136044f8ace43e7e2e1ad4f2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 25678 zcmcJ&cX*V=8t{G3JfVdqf(i;4=_0*mpMX+Dv2jGjf?&!fSxB-W*(4MTRS^M2K`A0C z_TD>U!wT3%u!0>y>R0j|0iCllxqEL18puWgbl^s)^JZK>D`m%MQ>g+*7kvDJMmaDpW&`6|JxFc6} z^&ng1*(FOu)y;#(A}=l*6sjXVXd-g=EiZ?vy9Z50E;wj&OdaJxGm)P+Ul~(9JZLWR z=)QaNR8J2M5P9{CoKW@hpoPen_dgJ--X0t%(&X5iW2%n_Ek&L@{pL{h^`MnV&4X`- zs-Fi3iCp~d3!&=o!NDS(I-C@$0UjJ8vi_Uyp&IBxYmuR2dxmO|2W>^NMNZn%AWs!|&`acx-E96s4|V0Kxo{D?Wm(`y;Gfx$J&`;#0ne#(c;z56rAL6@1RqDY2k#%!E4pqW~fgOfJLsjX)5RrmS3qv){gQ2X;2Rc+$9vm&Q=D~q^s@j8N=-4|e@>I%$VIpr`U7DwA zJQyyr^@s!VRILXiM9yvdVNBI|FjC~G#zmo;?!hRLYc8w~)eH|tiyZy<>3M3V2giz> zwxDmQX8ENYBhu@}H)HC24~`T0`mcv#>H-grXT}cB%2OA5FqWg<8cjFbgA>?uH(VI1 
zIUbxSa@(B@f3627i5%Rnb*ScfaI(mhHSDqZ9-Jc5Id4U-TIj(zksfW=<*G#<{71Np57tD8MIoB0^`+8}j{2j{SZj-HgOZuMZI zNW8huRkwR^uE;Y_+%ZVq<-sJ5X}d26se3$_EHZxjD}&U19-JrgyPOrO=RC+4`R&Xz zW9oSi3XmAEbFkXvL7_;~9m8_fW)F%m1-n^>T^_`lDx;?AbsX5CzD`aYtN=sUr(lRc-wD<^z&q@$ki*Kg7o*K zOyunMhC>E;QZ901hX#;=o>YiT-0~-F4fU-jMSdFB8JYZ>qt+`$?i#-(Pv;bPG7Yl# zW@HLIsS?@r$McXPPpU=cO#FEQI&OH2$nC_7?Tktm6tBcWHHSa=3E{w<$ zk$UUz4|PxKIl#&+6Y@$tbQ7*>Lr$JUwcHCoA#C9%3yfQmUJkS;fw5+$B#>LT`7sWVOiu9{et* zCnIx|C2K^!8t_c0CnMv?wVbbc%|e}z(d}tv)?z{q$%bJ59JvmM@}zcoIv;cA$U5e^ z`5DMyyd1e+5Q6z~%C1{i0ii8NZev^ztO<1y{zo5M^X(#c zcKIcyi)hP{J47bmxda(%cH~ZxOac0bsdqq~C)i%^|JO)SZ6M6c{JP32>$o(SAyRL&UhmJfTGGNb8>Y*)19uy4g zzsN8yM;;QH-M%u^#f+k#ZT(^T{KOy#qj2OAkx%}4Ii^b(dw(nQs7UjEV?td*DMudT z6l{G6LMcZc7nwV%VW<<78fZ&BA(Fcy4#7inZ!Co*!KAak=whWSIZbht|gm9-fwpb1Q|zO#0U7SMV>BW?2f#|DjT~~ zMxPyd84ZxP1;S2nhD&atj+|^MFpiCd0k|Cn|C1iwT`^OD*v$wLOqVWDYC6{K7<}R z@)pt4zhB051vBQz+ajN@GHo$qj=UrC)aI?BPO^e|w)J;$E1p5uC0Pqc-lMIj>2s2` zaAX_$`Mu{M)a=OnB8M$M07A`tBeReEVw1$n1!G zNvX&2ZK~;ylle+yPVe=|&>u&>W?Va0Kxo~OZ$wUfn)6jn>yGRenOU(m)G03Ftm(qR z5Ju|A4+K{S;-%EE%EN8zKjP{A*$6@@M}DF|t0+}V4@X#;pV4*mPk_)vM}A>e-{ut5 z(jP~D#ZRbQ9_l(|M%q%pVTnrEEi;(a(U$x!@<hx#`GUgN!Ckket*VhL=ye~jOh z0Ch^Ee?<;G_X1G2G@|h9A2iPka{4-_bqduf`JqJvr259Jw+Yfp?)>UX(8P4XK*{1C z8d5Od`I40)SxQnr|BY0>b0$IcloY*oi9ybv1JzeDGAaIw^T}!oQ$m*VP(gVuVU9#P*)}A=gl$5wFv5_M16m%L9X|pBb97CV}L=ftDx>mhRinWSmIi& zC9|NU*YZ+>T#KL{O7frSYmlo8)Kkf4lb<~Ik|o=_8RV)5^;UAt>GMFPuKGHR zD@(inGlN|9puS38+cl0^Ep-C|^;5F!{*R25YZ266$wzB`GRRdA8la@&>JXH0f!lPE zft-%djGjoiXbLpQoQ{-1Ztg*Yl_WMiPm2k6(m=UNo{*CbawiQGQ_|wup`aXd4m{sxvL8_M#-Ry+ZicW7w9-8t#kGmwYlE zRSzoQAiv#$PE@-o1Qk;KM`EPhfdLgMxqg6I+G;mOP#l%^(z~FP>(^YIJSBOph8W~J z0V+|_Xmvk>oK*moVjpuqV4_p5U!a7NF7F*=q+Gv1Qz*FR8~T-U6AhZGq+T~;0&3g_ zm`CWMTu{lz;t`ne2C1>|2 zG01fSG)>9-M;J$>&J|oBQpG_&+*r{%H%1USkB4&eb6#VH?P4MV^7VWmDW5KALJ_o4Iw#AmhpV=EheLZTG zM*TcGD2@7ibZ{CC@aT{<8t75$G#cbln=~5i(V=OS>(OCp6!WNU8s)c$dQYk%>V1Jn zz0#=Aquyy$gX!2`pBhY*$dw)L#WD-9&;R=zD#DzVz7pRniH(PWuUdM3ENnuU8;3Y~{ 
zJpXzuC^A{4J8+mp7Gjz@pB{qfghQ87u;APO^C0}$r z8Yz?CdE81}q2!w?Ca2UCWP2mE4A=T~)KJ1?|DMDj6S)#Ax|iihn3>pM(N*;8tY#pS z9eKv0OagE-y%!ECv-({nfYyo1Q?*gsKi*ga7pilT&7LcQ4|| ziLAt5ZTvG=(@R!rHP-6mQJ^F(k_e*U8dm$U24PTcvZCaDI8)aW z*1X;wWX4G12eeknNmn!mnVHyP(RE60y}y1KRCG-RQBr1|lIIs53NkanNu>AJEBR{l zH6T;h>v)qQH{iJ4k4;OOnRpLxl8X>xt2We?G$A|xrr>%dV>@C3Dow#1IG7?gDfzA4 zPqCoVq;L2)Qa97DL=XnkYKZJP!k}A}y!jo5VpG$^JhL zN^-@x)1upPVF&#a3#xmiGbDG=uPTC!lqpCGV;H$tB|o29V-Pdp&|RFNKiKyv(^ZG= zR#G^18^~lOC`G^S!Cjfh6xJH4U+`baH?U||4Ge=>#t!|7&na@Bl101NtFuf(oPSgB zekCoZ#=>CN?B%m1ZqWmnxoscEK&w1@P|40jo{?JX(L)&NwTFaRBoIa}mVdYzAw{ob zK>T=Er%@9Y%fFUAqU6IXMjMt%1w6{?e!l?L&6E|QT5cef^jm!ttY06OSeQy;jB{Fi zux0K(4EBVQo0pvp%P(|(QWkNal7A;pGpx+TYxR{prR1T@j?9Bqxa2|u;#MUyPR|-* znDeM%PqW=`zuJ_obk0meqB;W3iYpAOa$Z9tCC?C5ztq;$SLND;J*(vR4@VfeYUlN2 zD|wD)_H{BW<>VT(Za)C zE^36mL9{X4_Ad$~U~ek<`i{=bX<-yLz}{lN1U0s-b9!NKD>?F1ZCJD(hoan=zsHV& z73I6NY)crgq`_qu8x|kgM76W~>OH!%=tcTm6lD*<+en%`-OI|kb%VW+;`r(U!xGNx zY_DXyl0KhjTjurmFN+7z={)Y?z<}4#Z_RV zPZ)1;6rRIABepzxzLAUO`A8+7E1A@EC39aAcl)3R_DIRyt;|^{Njlb3$(LA&+#jj0 zB#NqfF#$@#@I)gQtq1IDre;oOSZS2Z>w_^?PfW@y4Q z+kNmo!C}WGmbtYA|G@Trs*7c=f3P2wjQqfiBT?bz0QM7c#Y^WKmU412^t6)O)=Y+# zMXAK$N`6stM!DI4WzpD15DO4d^}XLRS0C&*w!&O9e`QHm-zXFb!)x>>^IR6qJ?sy< zc*0vS?5Dl+9nE1wN1PR-YDu{H^ZW|*|AFg zQPO7Qv4%x^Xbj4fJ#=k<%UruK(X#uBVV1c*!<3c=g+~|`?eXK-kXo+)d4^?fEn$I{ zZXf*0I4a_9jZRdOrRC|LCm~l6?a7mfGquF_wzbTS16E(l+~baAeJi5%g*DJJr{g)u zC5u_#adv$hYPsn1XDoBM3veSX_nkP~urjyiq$RSooYmf(g{0XV|FdNqYgymG=v}-9 zJI~-IT88~}w~?!Hy;d37 z1=vBzeY+oqd*a^+9IT~FI0yqko8&_uJKg-o4Ok~FTlUvy-KwLMD6F%Vqk1(qa*-`7AVZ?%@NveTS4Vps z)>TW-d^6taD3MjDq??wRoz47JN12%-jFy%|X8&xNYZum?vGp=~s5;8Jz>d;#>G?U9 zxv{~x1X_KJtIyd%SWhi8cJwiFH7<)6C%3HS%ZKdzMf()io4L=fMJ^R3g-bDHTI&Dr zaD1Osl6^nL?)$!4S|4ZZYAWdxCg6Trj{DoRk&5*FR3-hjO#avG-Bje4zy@eJr}0&m zx&FZhYUw<|*ws{&fSF6SSj!ddjPFwu*|IVvgSB*hN;nv#yOmZmH^6~iZ*w1y**{UM_r9~G8$kjzRJ+L8KIvi|#i@Ip9!G?0)-ko6N zqBUQDZ?0ut{TmF6>}5*HFj@2^%v2|eua%*`{5ghAAoxSMAf83WH;yRZi{%#9cJP?{CHu`MRonPF~hut(CY#LdeRC68v9 
zD+?nB5w%<5%EHJvL@eRTUT!XI(#$3CVB{RUoSPTeQyJ#kfNjVyH&?KyGtA8uY-5JG zxq>~DVQ#Kq&t{mLE7)@x=H?3ae40&leOX3qm1bpbY*!LnWtfwLk;;h1R_5jf_ELtq zalm-z;^o{Pf|20x%%#&{ucleKYxgQ*s|<7P!Zv4^YZpe|BdV{$m0gZ=nqjUij5LRr zbIDQI8yV*6gT0wyuFtTyGR(~>?ClJ5a|(MW&5~}-R}foem|Hg(`IV?INw;pWZ5ig; zfW4n#t_|4sG^=!DyP5zZ!`$4%KFBavAMC>nbM?VKO0z0Ae=ErjW|(UOMusNpbCsJb z*ryrh<_fkm!(6+tT^Z)ug?*M`ZXdutPqXSCHjBQB*eb(Dwl-;W*q0gR`Um?e!(9Jh zypxIgTobg`Ry5vP*FV^=X;$at))HH#*$g-L z>xiw=Y^IaDp4ckQHqU7hU2E*mu!Wxe1Zoi@LHik2q931JrDKCcN~ zQhJ=0gIeJuk^CjW`2@}y7r9!3r||!#+Ok}W;H#XfWyLV^Amn3_^1e>7c z)Cr8Y%Cb*wg7i!+VJEhHo=; z6H~T-oC4T+TE>^%Xqd^KJC?7d@kLGZvPi^|4t`4+xdJU`bR@p4wPo3%Be>e7Y@>t1 zEL>Jny(8_|inJVfM$<5hENivRNX5C_B@VqL71lGkP3IVs2`Xm8q#77l7Tq7$=PikLUSW$#KPa-JbidW0d)r$nUfmY-%VMzSD3a>0RRggQ5WW+dZL&Y&4rv0C2xr@oQ&$tYljmNvI~S9JcPi4-r)k6uXtrx9yzt!H5r_>$SF(sI(^w{7vLOGs98GDhz= zkgaFKf-0_l32C=gnc}5UTtLF7M$7Z-atusH4gIX7R?FFI;|{v)B(P4)nhVD|=+bb& z>4a}bPqnbZ#jwASKql^4-O54NI&h|zqmNn&EXt3fzqVX5YZ-ZIcL!Za13X{Lzcc=z z^F=;iaJZ5Sv^=zA1d{Rm=owdvcj|1y8+#g9n|b7P5fSf42?t$Bgyd{eBrT5z7Uvhb zhC1`Ag0pzLStFq9CvdKoH@kgJ>&5XXQtzTe8mQR3cD^yB%I zmOEZ;YoOmt{kcxlvT~j|n~4f@;z$)v=j&)_&}C;@7`=xhFSDG`bM5C0leD6XwMoLP zV2R0_0PTUzwZ*U2vTJc015;hA>#3MM+SgFS=mUXMYNGV-Q0@f?x1W9lIMwIIX5v=R z&y5cQ%ls^nm|RU#uGQfdM#X_^v^@T73j^cDZikSXyq06QptpgE=!M)!VoNPWypT39 zdRqm&j$rl0k1dR{X25k?#(eTO!z@cg4FRv$a>5H)Rx(;6;0?siM~i{U3iqX;QP?vr zsX0?DjC2Tay_Va@8m(BCjEVzqBCS={5?G!eP4Z}BOeXmK!0&7L7hCM(KZ!j~lG10;QB)freB>5x;2 zIJML}Y>z4K(<&5yNXw)>?;=Uc(rtzTS{~MN!DU8OR{9IB(+E1XBo6(O;+5<$dtfi( z>nNO~ivuHB7NuK&k7+4gcPNmQq8kwjn8(SI`0{5-0%D@X)5)w6&`APcB~jh!={y6?P1tF6+q}kdG~H+=<4z~TH{)GAx5hzN zCvc0FKNrn&(CsDQR?Sx}zp*g7eW9V(@iZq4Hw`5s3!5Mu)pF-ETa0AV?9VCK9lngj z+1&QK8PSX=X$pKxOUuhY0h07{>wd15x3!%AywOXw`BAztNl40}8~8es>{-{PNhIX7 z+&XuhkxWDv3zY<=T5?Jo1MBj=dw_q%#tRNJFdlcyfz$Oqt#AK`I_vy$0Jo!)lc!tB zs5o$kmiDc8TFFQW06);u{M46>q@Vc(A|F!xHKP^lN~048{0I&9*5d{yqIFLZoNBrI zg!?RvUg-fpVU|`*Wj^bYRc=CR2v5-=^Nbe+be#w8)bhxBvzjyfWnwJ>DpNPY>?IPH zQ6^c-XIj>^d@HQi!#ImRP5HT&V;|1V19FG#;TPOJ+;W$ZWZ!uBCCSiccZBtNW&bBr 
z`4!H<)LlSx{qNt$eXZs7{rGjgMq*=?T4tZaeS^{8#W?~dXvo3c1W>yf5zu5#9sCyi z{tul;uxt52s*`n-_ppWJiAq50z?BVT-E}Qban=& zoFsncoHXPWQVSFN;b+dxz%nPf5Z5RT@n5YZ*~s}B7CGqq~Ri>c}{DI^QJ*;wF!u=yb7(w19+2Rl~e-Y$$ zl7*Mx8f9Rqs}nJDYhE(p`neF-C=CghY;nqxaf>8*rDP%g8|g3)qgh(4<*E#G>daHrFd z`$#K^V{>B$)5&1uMMCM(Ij-SQvk8Uo&$hP-I9l6WS! zWnj|1F~kYEJq>x)-)z#JkGSzsXni zZ)9jik+Ix6xoYQy03KU|7SMt?ISpGwGu{Z`!!8&MEiCiSVQW?-koOD5TAGX^cBl&E z&_{ovXNAd{=qCO*UJ3+q-2OF2HdzyyT4)f+A?L)QyqoEos;89oU9tj+P5z7;Ofc-+ z90cnHa^MpIve3vYK(KxwpS*J{v^Y@}dFNdjNFd)Q&oVSwWyaQ!RaMe3kYiirTI%m7 zp^XBmTV+OUqF4Nfbvjcf!_fGo2tCz^tYC-}ikOsYSkS(c7M$h4U@p>STOUw2m zTgt0L7uNRYSzRDc4K(vh5E!k+AuMko3sx2wnj#Qfpo#d1BAW-&_{!6uW%*UnH;0Dv ziXf1tPkczdWrU_~9*Ct{1Tt`Zdt}SgbPO*E0{ME(+sKwD{oAMGcuNpSqhrh{D$JzW z)g;bn#R!k@P7M|58ZO|=Wq}k-t_Q^#iDbtsIXIAY?V1|OHgz=+KpYatZ~L}FO#tDN zSFu1^2eSFs#)g*p+ysF^n?Sm38)|6Mzp|OY%Ys1Wv>9zEPX=7RWw4gR0y*;cHqfec z&nEJ^AdvYD?uJ$;(57}2`MfU(G zp(%E|t>Gj*$3S|IdEe0}4uT#LNd7zZO}!~nuU@u8AfIiTYN>bYij;I@)|d2#rc8#{ z%HsKV3go>P`x;srPq;b5rS2R^`!26KTIFP+T>@D>{YOWmy#wtU$be-NE%m3clvf9V zd{_H1{Y{l7Dx7Z!Ju;C0jVQIU{uDyH2Xf{9EK94Rw-(T&0(rH_%#QKk-K?J$NRL3? 
zUSReEE_^}U`|v%P$(d^%ElPUYE0A-SntjHLfhbQlfyWwwbU)IZ3(i%vOU~hHJdm1~ z&DpGB;hZk1;A>l)w@GcGwfSYy%N1Pyeu4aOqgi=w_05SrkF+6g5d!%;wT;$ru%kJ{ z!ydqSt6O4ZOA}F?4ILQBjB_4$G+GPjAZF(_bLQ(3sZvgNp*`J$1HO9rmZ>*c73nCv z?A$<(xhmVprZpFojeTL)jZouzyE!Z7nKHZYxqi@Hv#&MpT2UF~)AIr;{dHfgKCZ6y zw4JXzB#o80eR7(b8wFd#Os*Kx)MWWA#hD}` zj%qN*$lSX2I1E$oLH9ZoyJ~GPZkEuQ9KduF>XpA)@Af!BU*1 zGcwe-cD*(?43Vt4W3aOKk_b;zbIZ^|;}CY{Yh(YR$zo5>{trzhJ!K0=^_sf|Tkj*- z;tXXXEPa#=dxoOzEq$CUOPU((Y3WluugOr~-woQ_IYg})>u+VD`5DT+gQYj1sMFNk zJXpGkCueDD?j9_~V=am(*H=0D4p7V2c4dfZ<4kH#MV zxFc5II6L_%@9iARmKf1GtY2+|)j)k=?`XyZ^5XChV-4)RNc0Sg5b`)o@2ibMa2avA)50F{LOk=f;c;q{EswO(hjY47^^1PvF$vbFPuben*C&k&+VwSzD1~ zu=U3LZOTc3Y&fO?xFF6!a4H*T=;T25zFK8*%3F5uDS>S3@-9sl%rMimK)2$agRfL# zdUgy&ny1=UyYA)bo|69sa$V;?kmd@bA~N~kb8rG_UeMD>8@IzX`MZ+wfqeA!2S^ts zimao=Ec2EV-Tl#Dzgf(qJ?9~`S8_Vmc}iOIcsfsg1Nrs3 z1r}HNM~MTqObKMqWdjVZDJV<0_nK#Ms~5;0{Juy7lZ3@2*nSV?tz95Jr%rI#R{}0) zhhYe_cgu(H` zWV9i8nK=z#`#iIKNZtA{faGmeAlJS#-(X%BMjHadQjG^P!0eOi8C=)8d?n5t7V5hO z2O@18o)UNFKfw2t%#mSoJ(gx?I7B}(wJ_aYvJ9qD~)CKa*>)WUsY?r~MD5nR~ zVpLm}!8itH$Fk<+n* zJ}zWybQLBzt;%WbC|d_ zs>En%Oqjz&t`WAetF8Orz{IW*w#wTZJfhbO#@a+SllV1_QG_1<_8-iocuWi%N!uva zHrY7KM6nUJacm1siiF`1$69Qx6B^Q+rwan^o!^T!Fc=Nt@rA7Xo5RBfjL|HcH$@i( z(xf-i*60`u<*d$Z&gD(B!v-m~p#3hMyYQbA$bEDFiZ$ftO8C``RJ@)F?6uom&g+k# z!EKlv1(Od1GB1$q4X+zsR_1olr@T}rHHgWawBK+{kKI|$_AJ0Y#BL7_&jd{k_%n>hjEn}$i8>phsUSqmwVOHiQ81%-8EmrO^|H+@51CIxYx(M zO}kF>yJOcaOO^)m`qCrdC2_uk7yV>Q4|0bX)hqs_9;cs2HZfPDy|UzrKo;`@s1Y}M z(JAg_fn2=xTPt5!RukRf@X?il^sB9hytxx&4=>QT|Aebo1=6?S8_1jK^h%>Usl-6^ zuI+HtF5RbH6zrfZS%FT!shWCBXzAx+u$HR>8T*Wp=Mg<#w9_&%?sEd^|MtOFKFJqO z-FV?E1G)UPJx;!;CNi(^Re=mCF#V*SlB)c4J*xxB-L}=qr_BTMYgm_^5wG=kR`6?4 zOjiZeQx)gAeq@kQv}*}YUfM!EXjF5AtO7=*UKhxkwWgouij3r|<@j;y0;w3=(r|B1 zk0b3Za(y6=?9hf+#tReCdJ!GqKOVnyx2Zox^d4QsyvJq^fI+^nzTHxtt5P5w=uR| zj6VF%K=x1CVK|Yyb-}RusQSACIk>l(H={~?5)CVGH}$M*4aWyDI%$E%g}y}O9_H<_ zc45OvtGl4O`98x;m;+oB7+?j7-8&Ia{-WQIrDcdpg!#&ve}SUrXgM+2ZDah@aY z+-I+7ai2Z64s;9skgW#?+H#{ijXUO*=e_KQTZV@XQG3Nz`GvDYt>Tw%_@XE&v{{40 
zhUVdwIY@#2n0rDlnF$7d>J~Qg`i+S|!R2QOI`R1He~Fn7W@Jv}E-st#)b{?$2L&eL z_eZHEZ=8wE$9~5Lyv;!aMUE|b$%oMemwsn`#lom$d3XTus%yW4sH z@+jfK#>4o)r0Yy~7w9Iq1YAH9BR_kG`JkpEb8fox_!-Asoy-R%B%n-GvOxF5Tj1NJ z9JI=Ye89KZ*cWjI*p#dVm(hXTF!4`%RZ2kMK4@k?cq)+i%%98$$vTr!>yf3H?3P?g z2XflCimGg~&p(uU9s80BN6CXTB8g$hTO7Kmf>U4>a_fTktw_8)fo=`hi?kK(x z$m5S*WbIbk^ycWolKeP|CUEbWyI|qP# z_-1lg``(IW8?#!JnfPCLtsThGvtQ(sqC{~qd8jBSO{BFokUz(tNNqOkaXZCDKChDl z+}VLo66Hph`rGM2T$l&))75WK8_7(}pEqb2z&8WA^`8xVVqTw^HDg2e3d|Rf-U{T| zMIZ1rGE`QV1`|G+%NLGgfRoYp-if3hCdNg z?Y%u#VDLeVbR?OAcvZqUWXfI5Zo_Fzyb#MCmE+Y!iy$k@ylX~Hm^H{otrW!heS=w} z8b*rc{XqVzyD>DMl;s!1%QByA52WMX0byg8%g=P1gSUgg@!6kZd^DYFE<(@9(BZJK z6PvWlq7*k}k*-=0^wH)giasQ+C~3<_HX+RqLPYNZSvmrSvi_T>(B>vH3E(d{nJoOD z4Z|kp>)=FD89nnEm-Bpf)@!jQER)HCWj?GUf|k^^A;Z`>2z*^2{Up$F2i; zFKAy~@QHnGQEk2~#y@^O>(v+`qyMLw)`qqUSIc?xdxQ Tensor: - """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" - a, b, c = (3.4445, -4.7750, 2.0315) - was_2d = G.ndim == 2 - if was_2d: - G = G.unsqueeze(0) - X = G.bfloat16() - transposed = X.size(-2) > X.size(-1) - if transposed: - X = X.mT - X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) - for _ in range(steps): - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - if transposed: - X = X.mT - if was_2d: - X = X.squeeze(0) - return X - -# --- Parallel Muon optimizer --- - -class Muon(torch.optim.Optimizer): - """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. - - No DDP for bank params. After backward, this optimizer: - 1. Launches async reduce-scatter for all banks (biggest first) - 2. Returns control so Adam can step on small params while RS is in-flight - 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather - 4. 
Each all-gather overlaps with next bank's NS5 - """ - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - self._built = False - - def _build(self): - self._distributed = dist.is_available() and dist.is_initialized() - self._world_size = dist.get_world_size() if self._distributed else 1 - self._rank = dist.get_rank() if self._distributed else 0 - ws = self._world_size - - self._bank_meta = [] - for group in self.param_groups: - for p in group["params"]: - B = p.shape[0] - padded_B = ((B + ws - 1) // ws) * ws - shard_B = padded_B // ws - tail = p.shape[1:] - dev = p.device - self._bank_meta.append({ - 'p': p, - 'B': B, - 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, - }) - # Sort by size descending -- launch biggest reduce-scatters first - self._bank_meta.sort(key=lambda m: -m['p'].numel()) - self._built = True - - def launch_reduce_scatters(self): - """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" - if not self._built: - self._build() - if not self._distributed: - return - self._rs_futures = [] - for m in self._bank_meta: - p = m['p'] - if p.grad is None: - self._rs_futures.append(None) - continue - pg = m['padded_grad'] - pg[:m['B']].copy_(p.grad.bfloat16()) - if pg.shape[0] > m['B']: - pg[m['B']:].zero_() - fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) - self._rs_futures.append(fut) - - @torch.no_grad() - def step(self, closure=None): - """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - if not self._built: - self._build() - - for group in self.param_groups: - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group.get("weight_decay", 0.0) - - prev_ag_handle = None - prev_m = None - - sharded = self._distributed and hasattr(self, '_rs_futures') - - for i, m in enumerate(self._bank_meta): - p = m['p'] - if p.grad is None: - continue - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if sharded and self._rs_futures[i] is not None: - self._rs_futures[i].wait() - g = m['shard'] - buf = m['shard_mom'] - else: - g = p.grad.bfloat16() - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - - buf.mul_(momentum).add_(g) - if nesterov: - update = g.add(buf, alpha=momentum) - else: - update = buf - - update = zeropower_via_newtonschulz5(update, steps=backend_steps) - - if sharded: - prev_ag_handle = dist.all_gather_into_tensor( - m['full_update'], update, async_op=True) - prev_m = m - else: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if hasattr(self, '_rs_futures'): - del self._rs_futures - - return loss - -# --- Tokenizer evaluation helpers --- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -TOKENIZER_META_FORMAT_VERSION = 1 -TOKENIZER_META_SUFFIX = ".meta.npz" - - -def _derive_tokenizer_meta_path(tokenizer_path: str) -> Path: - tokenizer = Path(tokenizer_path) - if tokenizer.suffix == ".model": - return tokenizer.with_suffix(TOKENIZER_META_SUFFIX) - return tokenizer.with_name(tokenizer.name + TOKENIZER_META_SUFFIX) - - -def build_sentencepiece_luts_np( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = 
-    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
-    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
-    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
-    for token_id in range(sp_vocab_size):
-        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
-            continue
-        is_boundary_token_np[token_id] = False
-        if sp.is_byte(token_id):
-            base_bytes_np[token_id] = 1
-            continue
-        piece = sp.id_to_piece(token_id)
-        if piece.startswith("\u2581"):
-            has_leading_space_np[token_id] = True
-            piece = piece[1:]
-        base_bytes_np[token_id] = len(piece.encode("utf-8"))
-    return base_bytes_np, has_leading_space_np, is_boundary_token_np
-
-
-def load_tokenizer_meta_luts_np(
-    meta_path: Path, vocab_size: int
-) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, object]]:
-    def _scalar(value):
-        arr = np.asarray(value)
-        if arr.ndim == 0:
-            return arr.item()
-        first = arr.reshape(-1)[0]
-        return first.item() if hasattr(first, "item") else first
-
-    with np.load(meta_path, allow_pickle=False) as data:
-        format_version = int(_scalar(data["format_version"]))
-        if format_version != TOKENIZER_META_FORMAT_VERSION:
-            raise ValueError(
-                f"Unsupported tokenizer meta format_version={format_version} "
-                f"expected={TOKENIZER_META_FORMAT_VERSION}"
-            )
-        meta_vocab_size = int(_scalar(data["vocab_size"]))
-        tokenizer_kind = str(_scalar(data["tokenizer_kind"]))
-        source_model_name = str(_scalar(data["source_model_name"]))
-        base_bytes_np = np.asarray(data["base_bytes"], dtype=np.int16)
-        has_leading_space_np = np.asarray(data["has_leading_space"], dtype=np.bool_)
-        is_boundary_token_np = np.asarray(data["is_boundary_token"], dtype=np.bool_)
-        table_size = max(meta_vocab_size, vocab_size)
-        if base_bytes_np.shape[0] < table_size:
-            padded_base_bytes = np.zeros((table_size,), dtype=np.int16)
-            padded_has_leading_space = np.zeros((table_size,), dtype=np.bool_)
-            padded_is_boundary = np.ones((table_size,), dtype=np.bool_)
-            padded_base_bytes[: base_bytes_np.shape[0]] = base_bytes_np
-            padded_has_leading_space[: has_leading_space_np.shape[0]] = has_leading_space_np
-            padded_is_boundary[: is_boundary_token_np.shape[0]] = is_boundary_token_np
-            base_bytes_np = padded_base_bytes
-            has_leading_space_np = padded_has_leading_space
-            is_boundary_token_np = padded_is_boundary
-        metadata = {
-            "format_version": format_version,
-            "tokenizer_kind": tokenizer_kind,
-            "source_model_name": source_model_name,
-            "vocab_size": meta_vocab_size,
-            "meta_path": str(meta_path),
-        }
-        return base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata
-
-
-def load_tokenizer_luts(
-    tokenizer_path: str,
-    tokenizer_meta_path: str,
-    vocab_size: int,
-    device: torch.device,
-    *,
-    validate_meta: bool = False,
-) -> tuple[tuple[Tensor, Tensor, Tensor], dict[str, object]]:
-    meta_path = (
-        Path(tokenizer_meta_path) if tokenizer_meta_path
-        else _derive_tokenizer_meta_path(tokenizer_path)
-    )
-    if meta_path.exists():
-        base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata = (
-            load_tokenizer_meta_luts_np(meta_path, vocab_size)
-        )
-        if validate_meta and str(tokenizer_path).endswith(".model"):
-            sp = spm.SentencePieceProcessor(model_file=tokenizer_path)
-            sp_luts = build_sentencepiece_luts_np(sp, vocab_size)
-            if not (
-                np.array_equal(base_bytes_np, sp_luts[0])
-                and np.array_equal(has_leading_space_np, sp_luts[1])
-                and np.array_equal(is_boundary_token_np, sp_luts[2])
-            ):
-                raise ValueError(f"Tokenizer metadata mismatch for {meta_path}")
-        return (
-            torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
-            torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
-            torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
-        ), metadata
-    if not str(tokenizer_path).endswith(".model"):
-        raise FileNotFoundError(f"TOKENIZER_META_PATH does not exist: {meta_path}")
-    sp = spm.SentencePieceProcessor(model_file=tokenizer_path)
-    return build_sentencepiece_luts(sp, vocab_size, device), {
-        "tokenizer_kind": "sentencepiece",
-        "source_model_name": str(tokenizer_path),
-        "vocab_size": int(sp.vocab_size()),
-        "meta_path": None,
-        "fallback": True,
-    }
-
-def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
-    files = [Path(p) for p in sorted(glob.glob(pattern))]
-    if not files:
-        raise FileNotFoundError(f"No files found for pattern: {pattern}")
-    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
-    usable = ((tokens.numel() - 1) // seq_len) * seq_len
-    if usable <= 0:
-        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
-    return tokens[: usable + 1]
-
-def eval_val(
-    args: Hyperparameters,
-    model: nn.Module,
-    rank: int,
-    world_size: int,
-    device: torch.device,
-    grad_accum_steps: int,
-    val_tokens: Tensor,
-    base_bytes_lut: Tensor,
-    has_leading_space_lut: Tensor,
-    is_boundary_token_lut: Tensor,
-    eval_seq_len: int | None = None,
-) -> tuple[float, float]:
-    seq_len = eval_seq_len or args.train_seq_len
-    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
-    if local_batch_tokens < seq_len:
-        raise ValueError(
-            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
-            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
-            f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}"
-        )
-    local_batch_seqs = local_batch_tokens // seq_len
-    total_seqs = (val_tokens.numel() - 1) // seq_len
-    seq_start = (total_seqs * rank) // world_size
-    seq_end = (total_seqs * (rank + 1)) // world_size
-    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
-    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
-    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
-    model.eval()
-    with torch.inference_mode():
-        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
-            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
-            raw_start = batch_seq_start * seq_len
-            raw_end = batch_seq_end * seq_len + 1
-            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
-            x = local[:-1].reshape(-1, seq_len)
-            y = local[1:].reshape(-1, seq_len)
-            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
-                batch_loss = model(x, y).detach()
-            batch_token_count = float(y.numel())
-            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
-            val_token_count += batch_token_count
-            prev_ids = x.reshape(-1)
-            tgt_ids = y.reshape(-1)
-            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
-            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
-            val_byte_count += token_bytes.to(torch.float64).sum()
-    if dist.is_available() and dist.is_initialized():
-        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
-        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
-        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
-    val_loss = val_loss_sum / val_token_count
-    bits_per_token = val_loss.item() / math.log(2.0)
-    tokens_per_byte = val_token_count.item() / val_byte_count.item()
-    model.train()
-    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
-
-# --- Quantization helpers ---
-
-CONTROL_TENSOR_NAME_PATTERNS = tuple(
-    pattern
-    for pattern in os.environ.get(
-        "CONTROL_TENSOR_NAME_PATTERNS",
-        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda",
-    ).split(",")
-    if pattern
-)
-INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
-    pattern
-    for pattern in os.environ.get(
-        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
-        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
-    ).split(",")
-    if pattern
-)
-INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
-INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
-INT8_PER_ROW_SCALE_DTYPE = torch.float16
-INT8_CLIP_PERCENTILE = 99.99984
-INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
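For reference, the second value `eval_val` returns is bits-per-byte: the mean token NLL converted from nats to bits, rescaled by the tokens-per-byte ratio that the LUT byte accounting produces. A minimal standalone sketch of the arithmetic (the totals below are made-up illustrative values, not measured numbers):

```python
import math

# Illustrative totals only -- stand-ins for the accumulators in eval_val.
loss_sum = 2.5e6      # summed per-token NLL, in nats
token_count = 1.0e6   # number of scored tokens
byte_count = 4.2e6    # UTF-8 bytes those tokens decode to

val_loss = loss_sum / token_count           # mean NLL per token (nats)
bits_per_token = val_loss / math.log(2.0)   # nats -> bits
tokens_per_byte = token_count / byte_count
bits_per_byte = bits_per_token * tokens_per_byte

# Same quantity computed directly: total bits divided by total bytes,
# so the token count cancels out of the final metric.
assert abs(bits_per_byte - (loss_sum / math.log(2.0)) / byte_count) < 1e-12
```

Because the token count cancels, the metric only rewards tokenizers through the byte totals, which is why the leading-space and boundary LUTs matter.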
-def tensor_nbytes(t: Tensor) -> int:
-    return int(t.numel()) * int(t.element_size())
-
-def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
-    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
-        return t.float().contiguous()
-    if t.dtype in {torch.float32, torch.bfloat16}:
-        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
-        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
-    return t
-
-def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
-    t32 = t.float()
-    if t32.ndim == 2:
-        clip_abs = (
-            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
-            if t32.numel()
-            else torch.empty((t32.shape[0],), dtype=torch.float32)
-        )
-        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
-        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
-        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
-        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
-    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
-    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
-    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
-    return q, scale
-
-def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
-    quantized: dict[str, Tensor] = {}
-    scales: dict[str, Tensor] = {}
-    dtypes: dict[str, str] = {}
-    passthrough: dict[str, Tensor] = {}
-    passthrough_orig_dtypes: dict[str, str] = {}
-    qmeta: dict[str, dict[str, object]] = {}
-    stats = dict.fromkeys(
-        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
-        0,
-    )
-    for name, tensor in state_dict.items():
-        t = tensor.detach().to("cpu").contiguous()
-        stats["param_count"] += int(t.numel())
-        stats["num_tensors"] += 1
-        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
-        if not t.is_floating_point():
-            stats["num_nonfloat_tensors"] += 1
-            passthrough[name] = t
-            stats["int8_payload_bytes"] += tensor_nbytes(t)
-            continue
-        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
-            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
-            passthrough[name] = kept
-            stats["int8_payload_bytes"] += tensor_nbytes(kept)
-            continue
-        stats["num_float_tensors"] += 1
-        q, s = quantize_float_tensor(t)
-        if s.ndim > 0:
-            qmeta[name] = {"scheme": "per_row", "axis": 0}
-        quantized[name] = q
-        scales[name] = s
-        dtypes[name] = str(t.dtype).removeprefix("torch.")
-        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
-    obj: dict[str, object] = {
-        "__quant_format__": "int8_clean_per_row_v1",
-        "quantized": quantized,
-        "scales": scales,
-        "dtypes": dtypes,
-        "passthrough": passthrough,
-    }
-    if qmeta:
-        obj["qmeta"] = qmeta
-    if passthrough_orig_dtypes:
-        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
-    return obj, stats
-
-def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
-    out: dict[str, Tensor] = {}
-    qmeta = obj.get("qmeta", {})
-    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
-    for name, q in obj["quantized"].items():
-        dtype = getattr(torch, obj["dtypes"][name])
-        s = obj["scales"][name]
-        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
-            s = s.to(dtype=torch.float32)
-            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
-        else:
-            scale = float(s.item())
-            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
-    for name, t in obj["passthrough"].items():
-        out_t = t.detach().to("cpu").contiguous()
-        orig_dtype = passthrough_orig_dtypes.get(name)
-        if isinstance(orig_dtype, str):
-            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
-        out[name] = out_t
-    return out
-
-# --- Data loading ---
-
-def load_data_shard(file: Path) -> Tensor:
-    header_bytes = 256 * np.dtype("<i4").itemsize
-    num_tokens = _read_num_tokens(file)
-    tokens = np.memmap(file, mode="r", dtype="<u2", shape=(num_tokens,), offset=header_bytes)
-    return torch.from_numpy(np.asarray(tokens).astype(np.int32))
-
-_SHARD_NTOKENS_CACHE: dict[str, int] = {}
-
-def _read_num_tokens(file: Path) -> int:
-    key = str(file)
-    cached = _SHARD_NTOKENS_CACHE.get(key)
-    if cached is not None:
-        return cached
-    header = np.fromfile(file, dtype="<i4", count=256)
-    num_tokens = int(header[2])
-    _SHARD_NTOKENS_CACHE[key] = num_tokens
-    return num_tokens
-
-_MMAP_CACHE: dict[str, np.memmap] = {}
-
-def _get_shard_memmap(file: Path) -> np.memmap:
-    key = str(file)
-    mm = _MMAP_CACHE.get(key)
-    if mm is not None:
-        return mm
-    n = _read_num_tokens(file)
-    mm = np.memmap(file, mode="r", dtype="<u2", shape=(n,), offset=256 * np.dtype("<i4").itemsize)
-    _MMAP_CACHE[key] = mm
-    return mm
-
-    def _pick_coprime_stride(self, n: int) -> int:
-        if n <= 1:
-            return 1
-        while True:
-            s = int(self._rng.integers(1, n))
-            if math.gcd(s, n) == 1:
-                return s
-
-    def _reset_cursor(self, si: int, seq_len: int) -> None:
-        nt = int(self._num_tokens[si])
-        max_phase = min(seq_len - 1, max(0, nt - seq_len - 1))
-        phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0
-        bc = (nt - 1 - phase) // seq_len
-        self._cursor_phase[si] = phase
-        self._cursor_block_count[si] = bc
-        self._cursor_next[si] = 0
-        self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0
-        self._cursor_stride[si] = self._pick_coprime_stride(bc)
-        self._cursor_init[si] = True
-
-    def _ensure_cursor(self, si: int, seq_len: int) -> None:
-        if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]:
-            self._reset_cursor(si, seq_len)
-
-    def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None:
-        rem = count
-        while rem > 0:
-            self._ensure_cursor(si, seq_len)
-            bc = int(self._cursor_block_count[si])
-            ni = int(self._cursor_next[si])
-            take = min(rem, bc - ni)
-            phase = int(self._cursor_phase[si])
-            start = int(self._cursor_start[si])
-            stride = int(self._cursor_stride[si])
-            for j in range(take):
-                bi = (start + (ni + j) * stride) % bc
-                out.append((si, phase + bi * seq_len))
-            self._cursor_next[si] = ni + take
-            rem -= take
-
-    def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None:
-        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
-        num_seqs = local_tokens // seq_len
-        global_num_seqs = num_seqs * self.world_size
-        self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs)
-        bbc = (self._num_tokens - 1) // seq_len
-        eligible = bbc > 0
-        self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64)
-        self._base_block_counts = bbc[self._eligible_shards].astype(np.int64)
-
-    def _sample_global_windows(self) -> list[tuple[int, int]]:
-        assert self._cfg is not None and self._eligible_shards is not None
-        _, seq_len, _, gns = self._cfg
-        ec = int(self._eligible_shards.size)
-        progress = min(self._batches_built / 1800.0, 1.0)
-        remaining = np.empty(ec, dtype=np.float64)
-        for i, si in enumerate(self._eligible_shards.tolist()):
-            if self._cursor_init[si]:
-                r = int(self._cursor_block_count[si]) - int(self._cursor_next[si])
-                remaining[i] = float(max(r, 1))
-            else:
-                remaining[i] = float(self._base_block_counts[i])
-        alpha = 0.90 - 0.40 * progress
-        weights = np.power(remaining, alpha)
-        ws = float(weights.sum())
-        if not np.isfinite(ws) or ws <= 0.0:
-            weights = np.ones(ec, dtype=np.float64)
-            ws = float(weights.sum())
-        probs = weights / ws
-        low = min(max(8, self.world_size), ec, gns)
-        high = min(max(32, self.world_size * 8), ec, gns)
-        mix = max(1, min(int(round(low + progress * (high - low))), ec, gns))
-        cp = self._rng.choice(ec, size=mix, replace=False, p=probs)
-        cs = self._eligible_shards[cp]
-        cpr = probs[cp].copy()
-        cpr /= cpr.sum()
-        counts = np.ones(mix, dtype=np.int64)
-        extra = gns - mix
-        if extra > 0:
-            counts += self._rng.multinomial(extra, cpr).astype(np.int64)
-        perm = self._rng.permutation(mix)
-        cs, counts = cs[perm], counts[perm]
-        buckets: list[list[tuple[int, int]]] = []
-        for si, cnt in zip(cs.tolist(), counts.tolist()):
-            b: list[tuple[int, int]] = []
-            self._take_from_shard(int(si), seq_len, int(cnt), b)
-            if b:
-                if len(b) > 1:
-                    bp = self._rng.permutation(len(b))
-                    b = [b[int(k)] for k in bp.tolist()]
-                buckets.append(b)
-        windows: list[tuple[int, int]] = []
-        active = [i for i, bk in enumerate(buckets) if bk]
-        while active:
-            order = self._rng.permutation(len(active))
-            new_active: list[int] = []
-            for oi in order.tolist():
-                bi = active[oi]
-                if buckets[bi]:
-                    windows.append(buckets[bi].pop())
-                    if buckets[bi]:
-                        new_active.append(bi)
-            active = new_active
-        return windows
-
-    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
-        if self._cfg is None:
-            self._init_pipeline(global_tokens, seq_len, grad_accum_steps)
-        _, _, num_seqs, gns = self._cfg
-        gw = self._sample_global_windows()
-        local_w = gw[self.rank::self.world_size]
-        x = torch.empty((num_seqs, seq_len), dtype=torch.int64)
-        y = torch.empty((num_seqs, seq_len), dtype=torch.int64)
-        for slot, (si, pos) in enumerate(local_w):
-            mm = _get_shard_memmap(self.files[si])
-            window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64))
-            x[slot] = window[:-1]
-            y[slot] = window[1:]
-        self._batches_built += 1
-        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
-
-# --- Transformer modules ---
-
-class RMSNorm(nn.Module):
-    def __init__(self, eps: float | None = None):
-        super().__init__()
-        self.eps = eps
-    def forward(self, x: Tensor) -> Tensor:
-        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
-
-class CastedLinear(nn.Linear):
-    def forward(self, x: Tensor) -> Tensor:
-        w = self.weight.to(x.dtype)
-        bias = self.bias.to(x.dtype) if self.bias is not None else None
-        return F.linear(x, w, bias)
-
-def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
-    with torch.no_grad():
-        for name, param in module.named_parameters():
-            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
-                param.data = param.data.float()
-
-class Rotary(nn.Module):
-    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
-        super().__init__()
-        self.dim = dim
-        self.base = base
-        self.train_seq_len = train_seq_len
-        self.rope_dims = rope_dims if rope_dims > 0 else dim
-        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._cos_cached: Tensor | None = None
-        self._sin_cached: Tensor | None = None
-    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
-        if (
-            self._cos_cached is None
-            or self._sin_cached is None
-            or self._seq_len_cached != seq_len
-            or self._cos_cached.device != device
-        ):
-            rd = self.rope_dims
-            if seq_len > self.train_seq_len:
-                scale = seq_len / self.train_seq_len
-                new_base = self.base * (scale ** (rd / (rd - 2)))
-                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
-            else:
-                inv_freq = self.inv_freq.to(device)
-            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
-            freqs = torch.outer(t, inv_freq)
-            self._cos_cached = freqs.cos()[None, :, None, :]
-            self._sin_cached = freqs.sin()[None, :, None, :]
-            self._seq_len_cached = seq_len
-        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
-
-def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
-    if rope_dims > 0 and rope_dims < x.size(-1):
-        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
-        half = rope_dims // 2
-        x1, x2 = x_rope[..., :half], x_rope[..., half:]
-        x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
-        return torch.cat((x_rope, x_pass), dim=-1)
-    half = x.size(-1) // 2
-    x1, x2 = x[..., :half], x[..., half:]
-    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
-
-class CausalSelfAttention(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_base: float,
-        qk_gain_init: float,
-    ):
-        super().__init__()
-        if dim % num_heads != 0:
-            raise ValueError("model_dim must be divisible by num_heads")
-        if num_heads % num_kv_heads != 0:
-            raise ValueError("num_heads must be divisible by num_kv_heads")
-        self.num_heads = num_heads
-        self.num_kv_heads = num_kv_heads
-        self.head_dim = dim // num_heads
-        if self.head_dim % 2 != 0:
-            raise ValueError("head_dim must be even for RoPE")
-        # No CastedLinear -- weights come from banks
-        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
-        self.rope_dims = 0  # set by GPT.__init__ for partial RoPE
-        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024)
-        self.use_xsa = False  # set by GPT.__init__ for deep layers only
-
-    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
-        """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave).
-        y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv."""
-        B, T, H, D = y.shape
-        Hkv = v.size(-2)
-        group = H // Hkv
-        y_g = y.reshape(B, T, Hkv, group, D)  # [B, T, Hkv, group, D]
-        vn = F.normalize(v, dim=-1).unsqueeze(-2)  # [B, T, Hkv, 1, D] -- broadcast ready
-        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
-        return (y_g - proj).reshape(B, T, H, D)
-
-    def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None) -> Tensor:
-        if getattr(self, '_save_gptq', False):
-            self._gptq_qkv_in = x.detach()
-        bsz, seqlen, dim = x.shape
-        q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim)
-        k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
-        v = F.linear(x, v_w.to(x.dtype))
-        if v_embed is not None:
-            v = v + v_embed
-        v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
-        q = F.rms_norm(q, (q.size(-1),))
-        k = F.rms_norm(k, (k.size(-1),))
-        cos, sin = self.rotary(seqlen, x.device, q.dtype)
-        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
-        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
-        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
-        y = flash_attn_3_func(q, k, v, causal=True)
-        if self.use_xsa:
-            y = self._xsa_efficient(y, v)
-        y = y.reshape(bsz, seqlen, dim)
-        if getattr(self, '_save_gptq', False):
-            self._gptq_o_in = y.detach()
-        return F.linear(y, out_w.to(x.dtype))
-
-class SmearGate(nn.Module):
-    def __init__(self, dim: int):
-        super().__init__()
-        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
-    def forward(self, x: Tensor) -> Tensor:
-        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
-        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
-        return (1 - g) * x + g * x_prev
-
-class BigramHashEmbedding(nn.Module):
-    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
-        super().__init__()
-        self.bigram_vocab_size = bigram_vocab_size
-        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
-        nn.init.zeros_(self.embed.weight)
-        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
-        if self.proj is not None:
-            nn.init.zeros_(self.proj.weight)
-        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
-    def bigram_hash(self, tokens: Tensor) -> Tensor:
-        t = tokens.to(torch.int32)
-        mod = self.bigram_vocab_size - 1
-        out = torch.empty_like(t)
-        out[..., 0] = mod
-        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
-        return out.long()
-    def forward(self, token_ids: Tensor) -> Tensor:
-        h = self.embed(self.bigram_hash(token_ids))
-        if self.proj is not None:
-            h = self.proj(h)
-        return h * self.scale.to(dtype=h.dtype)
-
-class ValueEmbedding(nn.Module):
-    """Reinject token identity into attention values at specific layers.
-    Each table maps vocab tokens to a low-dim embedding, projected to model_dim."""
-    def __init__(self, vocab_size: int, ve_dim: int, model_dim: int):
-        super().__init__()
-        self.embed = nn.Embedding(vocab_size, ve_dim)
-        nn.init.normal_(self.embed.weight, std=0.01)
-        self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None
-        if self.proj is not None:
-            nn.init.zeros_(self.proj.weight)
-        self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))
-    def forward(self, token_ids: Tensor) -> Tensor:
-        h = self.embed(token_ids)
-        if self.proj is not None:
-            h = self.proj(h)
-        return h * self.scale.to(dtype=h.dtype)
-
-class MLP(nn.Module):
-    def __init__(self, dim: int, mlp_mult: int, neg_slope: float = 0.5):
-        super().__init__()
-        self.neg_slope = neg_slope
-        # No CastedLinear -- weights come from banks
-    def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor:
-        if getattr(self, '_save_gptq', False):
-            self._gptq_up_in = x.detach()
-        x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=self.neg_slope)
-        x2 = x.square()
-        if getattr(self, '_save_gptq', False):
-            self._gptq_down_in = x2.detach()
-        return F.linear(x2, down_w.to(x.dtype))
-
-class Block(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        num_kv_heads: int,
-        mlp_mult: int,
-        rope_base: float,
-        qk_gain_init: float,
-        layer_idx: int = 0,
-        ln_scale: bool = False,
-        neg_slope: float = 0.5,
-    ):
-        super().__init__()
-        self.layer_idx = layer_idx
-        self.attn_norm = RMSNorm()
-        self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
-        self.mlp = MLP(dim, mlp_mult, neg_slope=neg_slope)
-        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
-        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
-        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
-        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
-    def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None) -> Tensor:
-        mix = self.resid_mix.to(dtype=x.dtype)
-        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
-        attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed)
-        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
-        mlp_out = self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w)
-        return x_out + mlp_out
-
-class GPT(nn.Module):
-    def __init__(
-        self,
-        vocab_size: int,
-        num_layers: int,
-        model_dim: int,
-        num_heads: int,
-        num_kv_heads: int,
-        mlp_mult: int,
-        tie_embeddings: bool,
-        tied_embed_init_std: float,
-        logit_softcap: float,
-        rope_base: float,
-        qk_gain_init: float,
-        bigram_vocab_size: int = 0,
-        bigram_dim: int = 128,
-        xsa_last_n: int = 0,
-        rope_dims: int = 0,
-        ln_scale: bool = False,
-        ve_enabled: bool = False,
-        ve_dim: int = 128,
-        ve_layers: str = "9,10",
-        neg_slope: float = 0.5,
-    ):
-        super().__init__()
-        self._ve_target_dim = num_kv_heads * (model_dim // num_heads)  # kv_dim for value projection
-        if logit_softcap <= 0.0:
-            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
-        self.tie_embeddings = tie_embeddings
-        self.tied_embed_init_std = tied_embed_init_std
-        self.logit_softcap = logit_softcap
-        self.tok_emb = nn.Embedding(vocab_size, model_dim)
-        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
-        self.smear = SmearGate(model_dim)
-        self.num_encoder_layers = num_layers // 2
-        self.num_decoder_layers = num_layers - self.num_encoder_layers
-        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
-        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
-        # Parameter banks: contiguous 3D tensors for batched optimizer
-        head_dim = model_dim // num_heads
-        kv_dim = num_kv_heads * head_dim
-        mlp_dim = int(mlp_mult * model_dim)
-        self.num_layers = num_layers
-        self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim))
-        self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim))
-        self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim))
-        self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim))
-        self.blocks = nn.ModuleList(
-            [
-                Block(
-                    model_dim,
-                    num_heads,
-                    num_kv_heads,
-                    mlp_mult,
-                    rope_base,
-                    qk_gain_init,
-                    layer_idx=i,
-                    ln_scale=ln_scale,
-                    neg_slope=neg_slope,
-                )
-                for i in range(num_layers)
-            ]
-        )
-        if rope_dims > 0:
-            head_dim = model_dim // num_heads
-            for block in self.blocks:
-                block.attn.rope_dims = rope_dims
-                block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims)
-        self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else []
-        kv_dim_ve = self._ve_target_dim
-        if self.ve_layer_indices:
-            self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve)
-            self.ve_layer_scales = nn.ParameterList(
-                [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]
-            )
-        else:
-            self.ve_shared = None
-            self.ve_layer_scales = nn.ParameterList()
-        self.value_embeds = nn.ModuleList()  # keep empty for compat
-        self.final_norm = RMSNorm()
-        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
-        if self.lm_head is not None:
-            self.lm_head._zero_init = True
-        if xsa_last_n > 0:
-            for i in range(max(0, num_layers - xsa_last_n), num_layers):
-                self.blocks[i].attn.use_xsa = True
-        self._init_weights()
-
-    def _init_weights(self) -> None:
-        if self.tie_embeddings:
-            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
-        n = self.num_layers
-        proj_scale = 1.0 / math.sqrt(2 * n)
-        # Init banks: orthogonal, with proj layers scaled down and out/down zero-init
-        for i in range(n):
-            nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0)      # Q
-            nn.init.zeros_(self.qo_bank.data[n + i])                 # Out (zero init)
-            nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0)      # K
-            nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0)  # V
-            nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0)  # MLP up
-            nn.init.zeros_(self.mlp_down_bank.data[i])               # MLP down (zero init)
-            # Scale proj layers (out_proj and mlp_down are "proj" layers)
-            self.qo_bank.data[n + i].mul_(proj_scale)
-            self.mlp_down_bank.data[i].mul_(proj_scale)
-        # Init remaining nn.Linear modules (bigram proj, lm_head)
-        for name, module in self.named_modules():
-            if isinstance(module, nn.Linear):
-                if getattr(module, "_zero_init", False):
-                    nn.init.zeros_(module.weight)
-                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
-                    nn.init.orthogonal_(module.weight, gain=1.0)
-
-    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None:
-        """Get value embedding for a specific layer using shared table + per-layer scale."""
-        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
-            return None
-        if ve_cache is not None and 've' not in ve_cache:
-            ve_cache['ve'] = self.ve_shared(input_ids)
-        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
-        ve_idx = self.ve_layer_indices.index(layer_idx)
-        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)
-
-    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
-        n = self.num_layers
-        x = self.tok_emb(input_ids)
-        if self.bigram is not None:
-            x = x + self.bigram(input_ids)
-        x = F.rms_norm(x, (x.size(-1),))
-        x = self.smear(x)
-        x0 = x
-        skips: list[Tensor] = []
-        ve_cache: dict = {}
-        for i in range(self.num_encoder_layers):
-            ve = self._get_ve(i, input_ids, ve_cache)
-            x = self.blocks[i](x, x0,
-                               self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i],
-                               self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i],
-                               v_embed=ve)
-            skips.append(x)
-        for i in range(self.num_decoder_layers):
-            bi = self.num_encoder_layers + i
-            if skips:
-                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
-            ve = self._get_ve(bi, input_ids, ve_cache)
-            x = self.blocks[bi](x, x0,
-                                self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi],
-                                self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi],
-                                v_embed=ve)
-        x = self.final_norm(x)
-        x_flat = x.reshape(-1, x.size(-1))
-        targets = target_ids.reshape(-1)
-        if self.tie_embeddings:
-            logits_proj = F.linear(x_flat, self.tok_emb.weight)
-        else:
-            if self.lm_head is None:
-                raise RuntimeError("lm_head is required when tie_embeddings=False")
-            logits_proj = self.lm_head(x_flat)
-        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
-        return F.cross_entropy(logits.float(), targets, reduction="mean")
-
-    def forward_logits(self, input_ids: Tensor) -> Tensor:
-        """Return logits (bsz, seq_len, vocab) without computing loss."""
-        n = self.num_layers
-        x = self.tok_emb(input_ids)
-        if self.bigram is not None:
-            x = x + self.bigram(input_ids)
-        x = F.rms_norm(x, (x.size(-1),))
-        x = self.smear(x)
-        x0 = x
-        skips: list[Tensor] = []
-        ve_cache: dict = {}
-        for i in range(self.num_encoder_layers):
-            ve = self._get_ve(i, input_ids, ve_cache)
-            x = self.blocks[i](x, x0,
-                               self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i],
-                               self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i],
-                               v_embed=ve)
-            skips.append(x)
-        for i in range(self.num_decoder_layers):
-            bi = self.num_encoder_layers + i
-            if skips:
-                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
-            ve = self._get_ve(bi, input_ids, ve_cache)
-            x = self.blocks[bi](x, x0,
-                                self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi],
-                                self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi],
-                                v_embed=ve)
-        x = self.final_norm(x)
-        if self.tie_embeddings:
-            logits_proj = F.linear(x, self.tok_emb.weight)
-        else:
-            logits_proj = self.lm_head(x)
-        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
-
-# --- Sliding window evaluation ---
-
-def eval_val_sliding(
-    args: Hyperparameters,
-    base_model: nn.Module,
-    rank: int,
-    world_size: int,
-    device: torch.device,
-    val_tokens: Tensor,
-    base_bytes_lut: Tensor,
-    has_leading_space_lut: Tensor,
-    is_boundary_token_lut: Tensor,
-    stride: int,
-    batch_seqs: int = 32,
-    eval_seq_len: int | None = None,
-) -> tuple[float, float]:
-    """Sliding window evaluation: each token scored with maximum context."""
-    seq_len = eval_seq_len or args.train_seq_len
-    total_tokens = val_tokens.numel() - 1
-    window_starts = [ws for ws in range(0, total_tokens, stride)
-                     if min(ws + seq_len, total_tokens) - ws >= 1]
-    total_windows = len(window_starts)
-    my_s = (total_windows * rank) // world_size
-    my_e = (total_windows * (rank + 1)) // world_size
-    my_windows = window_starts[my_s:my_e]
-    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
-    token_count = torch.zeros((), device=device, dtype=torch.float64)
-    byte_count = torch.zeros((), device=device, dtype=torch.float64)
-    base_model.eval()
-    compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
-    with torch.inference_mode():
-        for bi in range(0, len(my_windows), batch_seqs):
-            batch_ws = my_windows[bi:bi + batch_seqs]
-            bsz = len(batch_ws)
-            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
-            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
-            wlens: list[int] = []
-            for i, ws in enumerate(batch_ws):
-                end = min(ws + seq_len, total_tokens)
-                wlen = end - ws
-                wlens.append(wlen)
-                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
-                x_batch[i, :wlen] = chunk[:-1]
-                y_batch[i, :wlen] = chunk[1:]
-            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-                logits = compiled_logits(x_batch)
compiled_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -def eval_val_sliding_ttt( - args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, - device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - stride: int, batch_seqs: int = 32, log0=print, -) -> tuple[float, float]: - """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, - then train on it. 
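# The evaluators above report bits per byte as (mean NLL in nats / ln 2) *
# (scored tokens / scored bytes). A minimal standalone sketch of that
# conversion (illustrative only; the real byte counts come from the
# tokenizer LUTs, not from these arguments):
#
# ```python
import math

def nll_to_bits_per_byte(mean_nll_nats: float, total_tokens: int, total_bytes: int) -> float:
    """Convert mean next-token NLL (nats) to bits per byte.

    bits/byte = (nll / ln 2) * (tokens / bytes), the same two-factor
    conversion used by the sliding-window evaluators.
    """
    bits_per_token = mean_nll_nats / math.log(2.0)
    tokens_per_byte = total_tokens / total_bytes
    return bits_per_token * tokens_per_byte

# Example: mean NLL of 2.8 nats over 1000 tokens covering 4200 bytes.
bpb = nll_to_bits_per_byte(2.8, 1000, 4200)
# ```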
def eval_val_sliding_ttt(
    args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int,
    device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor,
    stride: int, batch_seqs: int = 32, log0=print,
) -> tuple[float, float]:
    """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows,
    then train on it. Every token scored BEFORE any update that could use it."""
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    ttt_chunk = args.ttt_chunk_tokens

    # Pre-compute all window starts
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]

    # Assign each window to a chunk based on the first token it scores
    num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk
    chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)]
    for ws in window_starts:
        end = min(ws + seq_len, total_tokens)
        wlen = end - ws
        s = 0 if ws == 0 else max(wlen - stride, 0)
        scored_start = ws + s
        ci = min(scored_start // ttt_chunk, num_chunks - 1)
        chunk_windows[ci].append(ws)

    log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} "
         f"total_windows={len(window_starts)} stride={stride} "
         f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} "
         f"freeze_blocks={args.ttt_freeze_blocks}")

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    # Freeze first N blocks
    frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks))))
    ttt_params = []
    for name, p in base_model.named_parameters():
        freeze = False
        for bi in frozen_block_ids:
            if f"blocks.{bi}." in name:
                freeze = True
                break
        if freeze:
            p.requires_grad_(False)
        else:
            p.requires_grad_(True)
            ttt_params.append(p)

    log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} "
         f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}")

    optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum)
    t0 = time.perf_counter()

    for ci in range(num_chunks):
        windows = chunk_windows[ci]
        if not windows:
            continue
        chunk_start = ci * ttt_chunk
        chunk_end = min((ci + 1) * ttt_chunk, total_tokens)

        # --- Phase 1: SCORE this chunk's windows (inference_mode) ---
        my_s = (len(windows) * rank) // world_size
        my_e = (len(windows) * (rank + 1)) // world_size
        my_windows = windows[my_s:my_e]

        base_model.eval()
        with torch.inference_mode():
            for bi in range(0, len(my_windows), batch_seqs):
                batch_ws = my_windows[bi:bi + batch_seqs]
                bsz = len(batch_ws)
                x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
                y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
                wlens: list[int] = []
                for i, ws in enumerate(batch_ws):
                    end = min(ws + seq_len, total_tokens)
                    wlen = end - ws
                    wlens.append(wlen)
                    chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                    x_batch[i, :wlen] = chunk_tok[:-1]
                    y_batch[i, :wlen] = chunk_tok[1:]
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    logits = base_model.forward_logits(x_batch)
                nll = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)).float(),
                    y_batch.reshape(-1), reduction="none",
                ).reshape(bsz, seq_len)
                for i, ws in enumerate(batch_ws):
                    wlen = wlens[i]
                    s = 0 if ws == 0 else max(wlen - stride, 0)
                    scored_nll = nll[i, s:wlen].to(torch.float64)
                    loss_sum += scored_nll.sum()
                    token_count += float(wlen - s)
                    tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen]
                    tb = base_bytes_lut[tgt].to(torch.float64)
                    tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                    byte_count += tb.sum()

        # --- Phase 2: TRAIN on this chunk (already scored = legal) ---
        is_last_chunk = (ci == num_chunks - 1)
        if not is_last_chunk and args.ttt_epochs > 0:
            base_model.train()
            chunk_seqs = (chunk_end - chunk_start) // seq_len
            if chunk_seqs > 0:
                cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1)))
                for pg in optimizer.param_groups:
                    pg['lr'] = cos_lr
                my_seq_s = (chunk_seqs * rank) // world_size
                my_seq_e = (chunk_seqs * (rank + 1)) // world_size
                my_chunk_seqs = my_seq_e - my_seq_s
                for _ep in range(args.ttt_epochs):
                    for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs):
                        be = min(bs + args.ttt_batch_seqs, my_chunk_seqs)
                        actual_bs = my_seq_s + bs
                        start_tok = chunk_start + actual_bs * seq_len
                        end_tok = chunk_start + (my_seq_s + be) * seq_len + 1
                        if end_tok > val_tokens.numel():
                            continue
                        local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64)
                        x = local[:-1].reshape(-1, seq_len)
                        y = local[1:].reshape(-1, seq_len)
                        optimizer.zero_grad(set_to_none=True)
                        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                            loss = base_model(x, y)
                        loss.backward()
                        if world_size > 1:
                            for p in ttt_params:
                                if p.grad is not None:
                                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
                        torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip)
                        optimizer.step()

        if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1):
            elapsed = time.perf_counter() - t0
            rl = loss_sum.item() / max(token_count.item(), 1)
            rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0
            log0(f"  ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s")

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item())

    for p in base_model.parameters():
        p.requires_grad_(True)
    base_model.eval()

    log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} "
         f"elapsed={time.perf_counter() - t0:.1f}s")
    return val_loss, val_bpb


# --- GPTQ-lite int6 quantization ---

def _classify_param(name: str) -> str:
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"


def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    t32 = t.float()
    if t32.ndim == 2:
        best_q, best_s, best_err = None, None, float('inf')
        for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
            if pct < 1.0:
                row_clip = torch.quantile(t32.abs(), pct, dim=1)
            else:
                row_clip = t32.abs().amax(dim=1)
            s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
            q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8)
            recon = q.float() * s.float()[:, None]
            err = (t32 - recon).pow(2).mean().item()
            if err < best_err:
                best_q, best_s, best_err = q, s, err
        return best_q, best_s
    amax = t32.abs().max().item()
    scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8)
    return q, scale


def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]:
    """Convert 3D bank tensors into individual 2D tensors with standard names."""
    out: dict[str, Tensor] = {}
    n = num_layers
    for name, tensor in sd.items():
        if name == "qo_bank":
            for i in range(n):
                out[f"blocks.{i}.attn.c_q.weight"] = tensor[i]
                out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i]
        elif name == "kv_bank":
            for i in range(n):
                out[f"blocks.{i}.attn.c_k.weight"] = tensor[i]
                out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i]
        elif name == "mlp_up_bank":
            for i in range(n):
                out[f"blocks.{i}.mlp.fc.weight"] = tensor[i]
        elif name == "mlp_down_bank":
            for i in range(n):
                out[f"blocks.{i}.mlp.proj.weight"] = tensor[i]
        else:
            out[name] = tensor
    return out
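# quantize_int6_per_row maps each weight row to integers in [-31, 31] with one
# fp16 scale per row. A simplified pure-Python sketch of that per-row symmetric
# scheme (the real code additionally searches over clipping percentiles, which
# is omitted here; these helper names are illustrative, not from the repo):
#
# ```python
def quantize_row_int6(row, clip_range=31):
    """Symmetric quantization of one row to 6-bit signed integers."""
    amax = max(abs(v) for v in row) or 1.0   # guard all-zero rows
    scale = amax / clip_range                 # one scale per row
    q = [max(-clip_range, min(clip_range, round(v / scale))) for v in row]
    return q, scale

def dequantize_row(q, scale):
    return [v * scale for v in q]

q, s = quantize_row_int6([0.5, -1.0, 0.25, 0.0])
recon = dequantize_row(q, s)
# ```

The reconstruction error per element is bounded by half a quantization step (scale / 2), which is what the percentile search in the real code trades off against clipping outliers.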
def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Convert individual 2D tensors back into 3D bank tensors."""
    out: dict[str, Tensor] = {}
    n = num_layers
    # Reconstruct banks from individual weight keys
    qo_slices = [None] * (2 * n)
    kv_slices = [None] * (2 * n)
    up_slices = [None] * n
    down_slices = [None] * n
    consumed = set()
    for i in range(n):
        qk = f"blocks.{i}.attn.c_q.weight"
        if qk in sd:
            qo_slices[i] = sd[qk]
            consumed.add(qk)
        ok = f"blocks.{i}.attn.proj.weight"
        if ok in sd:
            qo_slices[n + i] = sd[ok]
            consumed.add(ok)
        kk = f"blocks.{i}.attn.c_k.weight"
        if kk in sd:
            kv_slices[i] = sd[kk]
            consumed.add(kk)
        vk = f"blocks.{i}.attn.c_v.weight"
        if vk in sd:
            kv_slices[n + i] = sd[vk]
            consumed.add(vk)
        fk = f"blocks.{i}.mlp.fc.weight"
        if fk in sd:
            up_slices[i] = sd[fk]
            consumed.add(fk)
        dk = f"blocks.{i}.mlp.proj.weight"
        if dk in sd:
            down_slices[i] = sd[dk]
            consumed.add(dk)
    out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype)
    out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype)
    out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype)
    out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype)
    for name, tensor in sd.items():
        if name not in consumed:
            out[name] = tensor
    return out


def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], clip_range: int = 31,
                        hessians: dict[str, Tensor] | None = None):
    num_layers_total = max(
        (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
        default=0,
    ) + 1
    late_k_layers = set(range(num_layers_total - 2, num_layers_total))
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    gptq_count, naive_count = 0, 0
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        if not t.is_floating_point() or t.numel() <= 65536:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if cat in int6_cats and t.ndim >= 1:
            H = hessians.get(name) if hessians else None
            if H is not None and t.ndim == 2:
                q, s = gptq_quantize_weight(t, H.cpu(), clip_range=clip_range)
                gptq_count += 1
            else:
                q, s = quantize_int6_per_row(t, clip_range=clip_range)
                naive_count += 1
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int6"}
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    if hessians:
        print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers", flush=True)
    return result, meta


def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta.get(name)
        if info is None:
            continue
        orig_dtype = orig.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            t = result[name]
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
        else:
            out[name] = (q.float() * float(s.item())).to(orig_dtype)
    return out


# --- Full Hessian GPTQ ---

def gptq_quantize_weight(W: Tensor, H: Tensor, clip_range: int = 31,
                         block_size: int = 128, percdamp: float = 0.01) -> tuple[Tensor, Tensor]:
    """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023)."""
    W_orig = W.float().clone()
    rows, cols = W_orig.shape
    H = H.float().clone()
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    damp = percdamp * H.diag().mean()
    H.diagonal().add_(damp)
    perm = torch.argsort(H.diag(), descending=True)
    invperm = torch.argsort(perm)
    W_perm = W_orig[:, perm].clone()
    W_perm[:, dead[perm]] = 0
    H = H[perm][:, perm]
    try:
        Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
        Hinv = torch.linalg.cholesky(Hinv, upper=True)
    except torch.linalg.LinAlgError:
        return quantize_int6_per_row(W_orig, clip_range)
    best_q, best_scale, best_err = None, None, float('inf')
    for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
        if pct < 1.0:
            row_clip = torch.quantile(W_orig.abs(), pct, dim=1)
        else:
            row_clip = W_orig.abs().amax(dim=1)
        s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
        sf = s.float()
        Q = torch.zeros(rows, cols, dtype=torch.int8)
        W_work = W_perm.clone()
        for i1 in range(0, cols, block_size):
            i2 = min(i1 + block_size, cols)
            W_block = W_work[:, i1:i2].clone()
            Hinv_block = Hinv[i1:i2, i1:i2]
            Err = torch.zeros(rows, i2 - i1)
            for j in range(i2 - i1):
                w_col = W_block[:, j]
                d = Hinv_block[j, j]
                q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range)
                Q[:, i1 + j] = q_col.to(torch.int8)
                err = (w_col - q_col.float() * sf) / d
                Err[:, j] = err
                W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0)
            if i2 < cols:
                W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:]
        recon = Q.float() * sf[:, None]
        mse = (W_perm - recon).pow(2).mean().item()
        if mse < best_err:
            best_q, best_scale, best_err = Q, s, mse
    best_q = best_q[:, invperm]
    return best_q, best_scale
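# Before the Cholesky factorization, gptq_quantize_weight stabilizes H by
# adding percdamp * mean(diag(H)) to the diagonal. A tiny pure-Python sketch
# of that damping step, showing it makes a singular PSD matrix invertible
# (helper name is illustrative, not from the repo):
#
# ```python
def damp_hessian(H, percdamp=0.01):
    """Add percdamp * mean(diag(H)) to the diagonal, as gptq_quantize_weight
    does before Cholesky, so a rank-deficient H becomes strictly SPD."""
    n = len(H)
    mean_diag = sum(H[i][i] for i in range(n)) / n
    damp = percdamp * mean_diag
    return [[H[i][j] + (damp if i == j else 0.0) for j in range(n)] for i in range(n)]

# Example: H from a single calibration vector x = (1, 1) is rank-1 (singular).
Hd = damp_hessian([[1.0, 1.0], [1.0, 1.0]])
# ```

After damping, the 2x2 determinant is (1.01)^2 - 1 = 0.0201 > 0, so the Cholesky factorization in the GPTQ loop cannot fail on this input.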
def _init_hessians(nl: int, dim: int, mlp_dim: int, device: torch.device) -> dict[str, Tensor]:
    h: dict[str, Tensor] = {}
    for i in range(nl):
        for k in ['c_q', 'c_k', 'c_v']:
            h[f'blocks.{i}.attn.{k}.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device)
        h[f'blocks.{i}.attn.proj.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device)
        h[f'blocks.{i}.mlp.fc.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device)
        h[f'blocks.{i}.mlp.proj.weight'] = torch.zeros(mlp_dim, mlp_dim, dtype=torch.float32, device=device)
    return h


def _accum_hessians(hessians: dict[str, Tensor], blocks: nn.ModuleList, dim: int, mlp_dim: int) -> None:
    for i, block in enumerate(blocks):
        qkv_in = block.attn._gptq_qkv_in.float().reshape(-1, dim)
        h_qkv = qkv_in.t() @ qkv_in
        hessians[f'blocks.{i}.attn.c_q.weight'] += h_qkv
        hessians[f'blocks.{i}.attn.c_k.weight'] += h_qkv
        hessians[f'blocks.{i}.attn.c_v.weight'] += h_qkv
        o_in = block.attn._gptq_o_in.float().reshape(-1, dim)
        hessians[f'blocks.{i}.attn.proj.weight'] += o_in.t() @ o_in
        up_in = block.mlp._gptq_up_in.float().reshape(-1, dim)
        hessians[f'blocks.{i}.mlp.fc.weight'] += up_in.t() @ up_in
        down_in = block.mlp._gptq_down_in.float().reshape(-1, mlp_dim)
        hessians[f'blocks.{i}.mlp.proj.weight'] += down_in.t() @ down_in


def _finalize_hessians(hessians: dict[str, Tensor], num_batches: int) -> None:
    for name in hessians:
        hessians[name] = hessians[name].cpu() / num_batches
        damp = 0.01 * torch.diag(hessians[name]).mean().clamp_min(1e-6)
        hessians[name] += damp * torch.eye(hessians[name].shape[0])


def gptq_collect_hessians(base_model: nn.Module, train_loader, device: torch.device,
                          num_batches: int, batch_tokens: int, seq_len: int,
                          grad_accum_steps: int) -> dict[str, Tensor]:
    """Collect Hessians H = X^T X from training data."""
    nl = base_model.num_layers
    dim = base_model.tok_emb.weight.shape[1]
    mlp_dim = base_model.mlp_up_bank.shape[1]
    hessians = _init_hessians(nl, dim, mlp_dim, device)
    for block in base_model.blocks:
        block.attn._save_gptq = True
        block.mlp._save_gptq = True
    base_model.eval()
    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        for _ in range(num_batches):
            x, y = train_loader.next_batch(batch_tokens, seq_len, grad_accum_steps)
            base_model(x, y)
            _accum_hessians(hessians, base_model.blocks, dim, mlp_dim)
    for block in base_model.blocks:
        block.attn._save_gptq = False
        block.mlp._save_gptq = False
    _finalize_hessians(hessians, num_batches)
    base_model.train()
    return hessians
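# gptq_collect_hessians accumulates the GPTQ calibration statistic
# H = X^T X over layer inputs, then averages over batches. A minimal
# pure-Python sketch of that accumulation (illustrative only; the real code
# does this with tensor matmuls on captured activations):
#
# ```python
def accumulate_hessian(batches):
    """Average of X^T X over batches; `batches` is a list of lists of
    d-dimensional row vectors (the layer inputs)."""
    d = len(batches[0][0])
    H = [[0.0] * d for _ in range(d)]
    for X in batches:
        for x in X:               # rank-1 update per input row: H += x x^T
            for i in range(d):
                for j in range(d):
                    H[i][j] += x[i] * x[j]
    nb = len(batches)
    return [[H[i][j] / nb for j in range(d)] for i in range(d)]

# One batch with a single input row (1, 2): H = [[1, 2], [2, 4]].
H = accumulate_hessian([[[1.0, 2.0]]])
# ```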
# --- Training ---

def main() -> None:
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)
    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    (base_bytes_lut, has_leading_space_lut, is_boundary_token_lut), tokenizer_metadata = load_tokenizer_luts(
        args.tokenizer_path, args.tokenizer_meta_path, args.vocab_size, device,
        validate_meta=args.tokenizer_meta_validate,
    )
    log0(f"tokenizer: kind={tokenizer_metadata.get('tokenizer_kind', 'unknown')} vocab={tokenizer_metadata.get('vocab_size', '?')}")
    if tokenizer_metadata.get('tokenizer_kind') == 'sentencepiece':
        sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
        if int(sp.vocab_size()) != args.vocab_size:
            raise ValueError(
                f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
            )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len
    val_seq_len = max(args.train_seq_len, effective_eval_seq_len)
    val_tokens = load_validation_tokens(args.val_files, val_seq_len)
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        bigram_vocab_size=args.bigram_vocab_size,
        bigram_dim=args.bigram_dim,
        xsa_last_n=args.xsa_last_n,
        rope_dims=args.rope_dims,
        ln_scale=args.ln_scale,
        ve_enabled=args.ve_enabled,
        ve_dim=args.ve_dim,
        ve_layers=args.ve_layers,
        neg_slope=args.negative_slope,
    ).to(device).bfloat16()
    # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward
    base_model.qo_bank.data = base_model.qo_bank.data.float()
    base_model.kv_bank.data = base_model.kv_bank.data.float()
    base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float()
    base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float()
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter,
    # and non-bank grads are manually all-reduced before Adam steps.
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model = compiled_model

    # Optimizer split:
    #  - 4 parameter banks -> Muon (batched Newton-Schulz)
    #  - token embedding -> Adam
    #  - scalars/control tensors -> Adam
    #  - bigram proj, VE proj -> Adam (small matrix params not worth banking)
    matrix_params = [
        base_model.qo_bank, base_model.kv_bank,
        base_model.mlp_up_bank, base_model.mlp_down_bank,
    ]
    block_named_params = list(base_model.blocks.named_parameters())
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    scalar_params.append(base_model.smear.gate)
    if base_model.bigram is not None:
        scalar_params.append(base_model.bigram.scale)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
    if base_model.bigram is not None:
        tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr})
        if base_model.bigram.proj is not None:
            scalar_params.append(base_model.bigram.proj.weight)
    if base_model.ve_shared is not None:
        tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr})
        if base_model.ve_shared.proj is not None:
            scalar_params.append(base_model.ve_shared.proj.weight)
        scalar_params.append(base_model.ve_shared.scale)
        for s in base_model.ve_layer_scales:
            scalar_params.append(s)
    optimizer_tok = torch.optim.AdamW(
        tok_params,
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        weight_decay=args.adam_wd,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
        weight_decay=args.muon_wd,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.AdamW(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        weight_decay=args.adam_wd,
        fused=True,
    )
    # Non-bank params that need manual all-reduce (replicated across GPUs)
    replicated_params = list(optimizer_tok.param_groups[0]["params"])
    for pg in optimizer_tok.param_groups[1:]:
        replicated_params.extend(pg["params"])
    replicated_params.extend(scalar_params)

    optimizer_head = None
    if base_model.lm_head is not None:
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        replicated_params.append(base_model.lm_head.weight)
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if optimizer_head is not None:
        optimizers.append(optimizer_head)
    log0(f"model_params:{sum(p.numel() for p in base_model.parameters())}")
    xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa]
    log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")
    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
    if args.use_gptq and max_wallclock_ms is not None:
        max_wallclock_ms -= args.gptq_reserve_ms
        log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms")

    def lr_mul(step: int, elapsed_ms: float) -> float:
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            # All-reduce all grads for warmup (simple, not optimized)
            if distributed:
                for p in base_model.parameters():
                    if p.grad is not None:
                        dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
    swa_state: dict[str, Tensor] | None = None
    swa_count = 0
    ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()}
    ema_decay = 0.997
    training_time_ms = 0.0
    stop_after_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()
        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break
        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum
        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale
        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        # === 3-phase overlapped optimizer step ===
        # Phase 1: Launch async reduce-scatter for banks (biggest first)
        optimizer_muon.launch_reduce_scatters()
        # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight)
        if distributed:
            for p in replicated_params:
                if p.grad is not None:
                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
        optimizer_tok.step()
        optimizer_scalar.step()
        if optimizer_head is not None:
            optimizer_head.step()
        # Phase 3: Wait for RS, local NS5, all-gather (banks processed last)
        optimizer_muon.step()
        zero_grad_all()
        # EMA update
        with torch.no_grad():
            for name, t in base_model.state_dict().items():
                ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0:
            if swa_state is None:
                swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
                swa_count = 1
                log0(f"swa:start step:{step}")
            else:
                for name, t in base_model.state_dict().items():
                    swa_state[name] += t.detach().cpu()
                swa_count += 1
        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step
    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )
    # Apply EMA weights
    log0("ema:applying EMA weights")
    current_state = base_model.state_dict()
    avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()}
    base_model.load_state_dict(avg_state, strict=True)
    torch.cuda.synchronize()
    t_diag = time.perf_counter()
    diag_val_loss, diag_val_bpb = eval_val(
        args, compiled_model, rank, world_size, device, grad_accum_steps,
        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
    )
    torch.cuda.synchronize()
    log0(
        f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms"
    )
    export_sd = base_model.state_dict()
    if master_process:
        torch.save(export_sd, "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")
    # Unbank 3D tensors into individual 2D tensors for quantization
    sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()}
    unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers)
    # GPTQ calibration: collect Hessians from training data
    gptq_hessians = None
    if args.use_gptq:
        t_gptq = time.perf_counter()
        log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...")
        calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
        gptq_hessians = gptq_collect_hessians(
            base_model, calib_loader, device, num_batches=args.gptq_calib_samples,
            batch_tokens=args.train_batch_tokens, seq_len=args.train_seq_len,
            grad_accum_steps=grad_accum_steps)
        del calib_loader
        gptq_elapsed = time.perf_counter() - t_gptq
        log0(f"gptq:calibrated {len(gptq_hessians)} layers in {gptq_elapsed:.1f}s")
        torch.cuda.empty_cache()
    quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}, clip_range=args.quant_clip_range, hessians=gptq_hessians)
    quant_buf = io.BytesIO()
    torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = lzma.compress(quant_raw, preset=6)
    if master_process:
        with open("final_model.int6.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = len(quant_blob)
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes")
        log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes")
    if distributed:
        dist.barrier()
    with open("final_model.int6.ptz", "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(
        io.BytesIO(lzma.decompress(quant_blob_disk)),
        map_location="cpu",
    )
    deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd)
    # Re-bank the dequantized tensors
    deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu)
    eval_model = GPT(
        vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
        num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
        bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, ln_scale=args.ln_scale, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - neg_slope=args.negative_slope, - ).to(device).bfloat16() - eval_model.qo_bank.data = eval_model.qo_bank.data.float() - eval_model.kv_bank.data = eval_model.kv_bank.data.float() - eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() - eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - 
log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - # Legal score-first TTT (PR #461 recipe) - if args.ttt_enabled: - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_loss, ttt_bpb = eval_val_sliding_ttt( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, log0=log0, - ) - torch.cuda.synchronize() - log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") - log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed1337.log b/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed1337.log deleted file mode 100644 index 0e6403215a..0000000000 --- a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed1337.log +++ /dev/null @@ -1,2505 +0,0 @@ -from __future__ import annotations -import copy -import glob -import io -import lzma -import math -import os -import random -import subprocess -import 
sys -import time -import uuid -from pathlib import Path -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func as flash_attn_3_func -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - embed_lr = 
float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - muon_wd = float(os.environ.get("MUON_WD", 0.04)) - adam_wd = float(os.environ.get("ADAM_WD", 0.04)) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - rope_dims = int(os.environ.get("ROPE_DIMS", 16)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) - ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) - ve_dim = int(os.environ.get("VE_DIM", 128)) - ve_layers = os.environ.get("VE_LAYERS", "9,10") - ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) - ttt_lr = float(os.environ.get("TTT_LR", 0.002)) - ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) - ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) - ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) - ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) - ttt_batch_seqs 
= int(os.environ.get("TTT_BATCH_SEQS", 32)) - ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) - negative_slope = float(os.environ.get("NEGATIVE_SLOPE", 0.5)) - use_gptq = bool(int(os.environ.get("USE_GPTQ", "0"))) - gptq_calib_samples = int(os.environ.get("GPTQ_CALIB_SAMPLES", "64")) - gptq_reserve_ms = float(os.environ.get("GPTQ_RESERVE_MS", "14000")) - quant_clip_range = int(os.environ.get("QUANT_CLIP_RANGE", 31)) - tokenizer_meta_path = os.environ.get("TOKENIZER_META_PATH", "") - tokenizer_meta_validate = bool(int(os.environ.get("TOKENIZER_META_VALIDATE", "0"))) - -# --- Batched Newton-Schulz orthogonalization --- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: - """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" - a, b, c = (3.4445, -4.7750, 2.0315) - was_2d = G.ndim == 2 - if was_2d: - G = G.unsqueeze(0) - X = G.bfloat16() - transposed = X.size(-2) > X.size(-1) - if transposed: - X = X.mT - X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) - for _ in range(steps): - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - if transposed: - X = X.mT - if was_2d: - X = X.squeeze(0) - return X - -# --- Parallel Muon optimizer --- - -class Muon(torch.optim.Optimizer): - """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. - - No DDP for bank params. After backward, this optimizer: - 1. Launches async reduce-scatter for all banks (biggest first) - 2. Returns control so Adam can step on small params while RS is in-flight - 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather - 4. 
Each all-gather overlaps with next bank's NS5 - """ - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - self._built = False - - def _build(self): - self._distributed = dist.is_available() and dist.is_initialized() - self._world_size = dist.get_world_size() if self._distributed else 1 - self._rank = dist.get_rank() if self._distributed else 0 - ws = self._world_size - - self._bank_meta = [] - for group in self.param_groups: - for p in group["params"]: - B = p.shape[0] - padded_B = ((B + ws - 1) // ws) * ws - shard_B = padded_B // ws - tail = p.shape[1:] - dev = p.device - self._bank_meta.append({ - 'p': p, - 'B': B, - 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, - }) - # Sort by size descending -- launch biggest reduce-scatters first - self._bank_meta.sort(key=lambda m: -m['p'].numel()) - self._built = True - - def launch_reduce_scatters(self): - """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" - if not self._built: - self._build() - if not self._distributed: - return - self._rs_futures = [] - for m in self._bank_meta: - p = m['p'] - if p.grad is None: - self._rs_futures.append(None) - continue - pg = m['padded_grad'] - pg[:m['B']].copy_(p.grad.bfloat16()) - if pg.shape[0] > m['B']: - pg[m['B']:].zero_() - fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) - self._rs_futures.append(fut) - - @torch.no_grad() - def step(self, closure=None): - """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - if not self._built: - self._build() - - for group in self.param_groups: - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group.get("weight_decay", 0.0) - - prev_ag_handle = None - prev_m = None - - sharded = self._distributed and hasattr(self, '_rs_futures') - - for i, m in enumerate(self._bank_meta): - p = m['p'] - if p.grad is None: - continue - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if sharded and self._rs_futures[i] is not None: - self._rs_futures[i].wait() - g = m['shard'] - buf = m['shard_mom'] - else: - g = p.grad.bfloat16() - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - - buf.mul_(momentum).add_(g) - if nesterov: - update = g.add(buf, alpha=momentum) - else: - update = buf - - update = zeropower_via_newtonschulz5(update, steps=backend_steps) - - if sharded: - prev_ag_handle = dist.all_gather_into_tensor( - m['full_update'], update, async_op=True) - prev_m = m - else: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if hasattr(self, '_rs_futures'): - del self._rs_futures - - return loss - -# --- Tokenizer evaluation helpers --- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -TOKENIZER_META_FORMAT_VERSION = 1 -TOKENIZER_META_SUFFIX = ".meta.npz" - - -def _derive_tokenizer_meta_path(tokenizer_path: str) -> Path: - tokenizer = Path(tokenizer_path) - if tokenizer.suffix == ".model": - return tokenizer.with_suffix(TOKENIZER_META_SUFFIX) - return tokenizer.with_name(tokenizer.name + TOKENIZER_META_SUFFIX) - - -def build_sentencepiece_luts_np( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = 
max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return base_bytes_np, has_leading_space_np, is_boundary_token_np - - -def load_tokenizer_meta_luts_np( - meta_path: Path, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, object]]: - def _scalar(value): - arr = np.asarray(value) - if arr.ndim == 0: - return arr.item() - first = arr.reshape(-1)[0] - return first.item() if hasattr(first, "item") else first - - with np.load(meta_path, allow_pickle=False) as data: - format_version = int(_scalar(data["format_version"])) - if format_version != TOKENIZER_META_FORMAT_VERSION: - raise ValueError( - f"Unsupported tokenizer meta format_version={format_version} " - f"expected={TOKENIZER_META_FORMAT_VERSION}" - ) - meta_vocab_size = int(_scalar(data["vocab_size"])) - tokenizer_kind = str(_scalar(data["tokenizer_kind"])) - source_model_name = str(_scalar(data["source_model_name"])) - base_bytes_np = np.asarray(data["base_bytes"], dtype=np.int16) - has_leading_space_np = np.asarray(data["has_leading_space"], dtype=np.bool_) - is_boundary_token_np = np.asarray(data["is_boundary_token"], dtype=np.bool_) - table_size = max(meta_vocab_size, vocab_size) - if base_bytes_np.shape[0] < table_size: - padded_base_bytes = np.zeros((table_size,), dtype=np.int16) - padded_has_leading_space = np.zeros((table_size,), dtype=np.bool_) - padded_is_boundary = np.ones((table_size,), 
dtype=np.bool_) - padded_base_bytes[: base_bytes_np.shape[0]] = base_bytes_np - padded_has_leading_space[: has_leading_space_np.shape[0]] = has_leading_space_np - padded_is_boundary[: is_boundary_token_np.shape[0]] = is_boundary_token_np - base_bytes_np = padded_base_bytes - has_leading_space_np = padded_has_leading_space - is_boundary_token_np = padded_is_boundary - metadata = { - "format_version": format_version, - "tokenizer_kind": tokenizer_kind, - "source_model_name": source_model_name, - "vocab_size": meta_vocab_size, - "meta_path": str(meta_path), - } - return base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata - - -def load_tokenizer_luts( - tokenizer_path: str, - tokenizer_meta_path: str, - vocab_size: int, - device: torch.device, - *, - validate_meta: bool = False, -) -> tuple[tuple[Tensor, Tensor, Tensor], dict[str, object]]: - meta_path = ( - Path(tokenizer_meta_path) if tokenizer_meta_path - else _derive_tokenizer_meta_path(tokenizer_path) - ) - if meta_path.exists(): - base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata = ( - load_tokenizer_meta_luts_np(meta_path, vocab_size) - ) - if validate_meta and str(tokenizer_path).endswith(".model"): - sp = spm.SentencePieceProcessor(model_file=tokenizer_path) - sp_luts = build_sentencepiece_luts_np(sp, vocab_size) - if not ( - np.array_equal(base_bytes_np, sp_luts[0]) - and np.array_equal(has_leading_space_np, sp_luts[1]) - and np.array_equal(is_boundary_token_np, sp_luts[2]) - ): - raise ValueError(f"Tokenizer metadata mismatch for {meta_path}") - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ), metadata - if not str(tokenizer_path).endswith(".model"): - raise FileNotFoundError(f"TOKENIZER_META_PATH does not exist: {meta_path}") - sp = spm.SentencePieceProcessor(model_file=tokenizer_path) - return 
build_sentencepiece_luts(sp, vocab_size, device), { - "tokenizer_kind": "sentencepiece", - "source_model_name": str(tokenizer_path), - "vocab_size": int(sp.vocab_size()), - "meta_path": None, - "fallback": True, - } - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = 
batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# --- Quantization helpers --- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 
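The constants above drive the int8 export path: any float tensor larger than `INT8_KEEP_FLOAT_MAX_NUMEL` is clipped per row at the `INT8_CLIP_PERCENTILE` quantile of its absolute values, scaled by clip/127 (floored at 1/127), and stored as int8 plus one float16 scale per row. A minimal NumPy sketch of that scheme (a standalone illustration with hypothetical names, not the script's torch implementation):

```python
import numpy as np

CLIP_Q = 0.9999984  # mirrors INT8_CLIP_PERCENTILE / 100

def quantize_per_row(w: np.ndarray):
    # Per-row clip threshold at the CLIP_Q quantile of |w|; scale = clip / 127,
    # floored at 1/127 as in the script's clamp_min(1.0 / 127.0).
    clip = np.quantile(np.abs(w), CLIP_Q, axis=1)
    scale = np.maximum(clip / 127.0, 1.0 / 127.0)
    clipped = np.clip(w, -clip[:, None], clip[:, None])
    q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8)
    return q, scale.astype(np.float16), clipped

def dequantize_per_row(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Reconstruct by broadcasting the per-row scale back over each row.
    return q.astype(np.float32) * scale.astype(np.float32)[:, None]

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 1024)).astype(np.float32)
q, s, clipped = quantize_per_row(w)
deq = dequantize_per_row(q, s)
# Within the clip range, round-trip error is at most half a quantization step
# plus a small float16 scale-rounding term.
err = np.abs(deq - clipped)
```

The per-row scale keeps the payload at one byte per weight plus two bytes per row, which is why the script stores scales in `INT8_PER_ROW_SCALE_DTYPE = torch.float16`.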
-def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += 
tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - -# --- Data loading --- - -_SHARD_NTOKENS_CACHE: dict[str, int] = {} -_MMAP_CACHE: dict[str, np.memmap] = {} - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype("<i4").itemsize - header = np.fromfile(file, dtype="<i4", count=256) - num_tokens = int(header[2]) - tokens = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes) - return torch.from_numpy(tokens.astype(np.int32)) -def _read_num_tokens(file: Path) -> int: - key = str(file) - cached = _SHARD_NTOKENS_CACHE.get(key) - if cached is not None: - return cached - header = np.fromfile(file, dtype="<i4", count=256) - n = int(header[2]) - _SHARD_NTOKENS_CACHE[key] = n - return n -def _get_shard_memmap(file: Path) -> np.memmap: - key = str(file) - mm = _MMAP_CACHE.get(key) - if mm is not None: - return mm - n = _read_num_tokens(file) - mm = np.memmap(file, mode="r", dtype="<u2", offset=256 * np.dtype("<i4").itemsize, shape=(n,)) - _MMAP_CACHE[key] = mm - return mm -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.files = [Path(p) for p in sorted(glob.glob(pattern))] - if not self.files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - self.rank = rank - self.world_size = world_size - self.device = device - self._rng = np.random.default_rng(1234) - self._num_tokens = np.array([_read_num_tokens(f) for f in self.files], dtype=np.int64) - ns = len(self.files) - self._cursor_phase = np.zeros(ns, dtype=np.int64) - self._cursor_block_count = np.zeros(ns, dtype=np.int64) - self._cursor_next = np.zeros(ns, dtype=np.int64) - self._cursor_start = np.zeros(ns, dtype=np.int64) - self._cursor_stride = np.ones(ns, dtype=np.int64) - self._cursor_init = np.zeros(ns, dtype=bool) - self._cfg: tuple[int, int, int, int] | None = None - self._eligible_shards: np.ndarray | None = None - self._base_block_counts: np.ndarray | None = None - self._batches_built = 0 - def _pick_coprime_stride(self, n: int) -> int: - if n <= 1: - return 1 - while True: - s = int(self._rng.integers(1, n)) - if math.gcd(s, n) == 1: - return s - def _reset_cursor(self, si: int, seq_len: int) -> None: - nt = int(self._num_tokens[si]) - max_phase = min(seq_len - 1, max(0, nt - seq_len - 1)) - phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0 - bc = (nt - 1 - phase) // seq_len - self._cursor_phase[si] = phase - self._cursor_block_count[si] = bc - self._cursor_next[si] = 0 - self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0 - self._cursor_stride[si] = self._pick_coprime_stride(bc) - self._cursor_init[si] = True - def _ensure_cursor(self, si: int, seq_len: int) -> None: - if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]: - self._reset_cursor(si, seq_len) - def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None: - rem = count - while rem > 0: - self._ensure_cursor(si, seq_len) - bc = int(self._cursor_block_count[si]) - ni = int(self._cursor_next[si]) - take = min(rem, bc - ni) - phase = int(self._cursor_phase[si]) - start = int(self._cursor_start[si]) - stride = int(self._cursor_stride[si]) - for j in range(take): - bi = (start + (ni + j) * stride) % bc - out.append((si, phase + bi * seq_len)) - self._cursor_next[si] = ni + take - rem -= take - def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - num_seqs = local_tokens // seq_len - global_num_seqs = num_seqs * self.world_size - self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) - bbc = (self._num_tokens - 1) // seq_len -
eligible = bbc > 0 - self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) - self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) - def _sample_global_windows(self) -> list[tuple[int, int]]: - assert self._cfg is not None and self._eligible_shards is not None - _, seq_len, _, gns = self._cfg - ec = int(self._eligible_shards.size) - progress = min(self._batches_built / 1800.0, 1.0) - remaining = np.empty(ec, dtype=np.float64) - for i, si in enumerate(self._eligible_shards.tolist()): - if self._cursor_init[si]: - r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) - remaining[i] = float(max(r, 1)) - else: - remaining[i] = float(self._base_block_counts[i]) - alpha = 0.90 - 0.40 * progress - weights = np.power(remaining, alpha) - ws = float(weights.sum()) - if not np.isfinite(ws) or ws <= 0.0: - weights = np.ones(ec, dtype=np.float64) - ws = float(weights.sum()) - probs = weights / ws - low = min(max(8, self.world_size), ec, gns) - high = min(max(32, self.world_size * 8), ec, gns) - mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) - cp = self._rng.choice(ec, size=mix, replace=False, p=probs) - cs = self._eligible_shards[cp] - cpr = probs[cp].copy() - cpr /= cpr.sum() - counts = np.ones(mix, dtype=np.int64) - extra = gns - mix - if extra > 0: - counts += self._rng.multinomial(extra, cpr).astype(np.int64) - perm = self._rng.permutation(mix) - cs, counts = cs[perm], counts[perm] - buckets: list[list[tuple[int, int]]] = [] - for si, cnt in zip(cs.tolist(), counts.tolist()): - b: list[tuple[int, int]] = [] - self._take_from_shard(int(si), seq_len, int(cnt), b) - if b: - if len(b) > 1: - bp = self._rng.permutation(len(b)) - b = [b[int(k)] for k in bp.tolist()] - buckets.append(b) - windows: list[tuple[int, int]] = [] - active = [i for i, bk in enumerate(buckets) if bk] - while active: - order = self._rng.permutation(len(active)) - new_active: list[int] = [] - for oi in order.tolist(): - bi = active[oi] - if 
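`_sample_global_windows` weights each shard by `remaining ** alpha`, with `alpha` annealed from 0.90 toward 0.50 as `progress` grows, so sampling starts strongly size-proportional and flattens over time. A pure-Python sketch of just that weighting (the schedule constants are copied from above; everything else is simplified away):

```python
def shard_probs(remaining_blocks, progress):
    """Power-law sampling weights over per-shard remaining block counts."""
    alpha = 0.90 - 0.40 * min(progress, 1.0)
    weights = [max(r, 1) ** alpha for r in remaining_blocks]
    total = sum(weights)
    return [w / total for w in weights]
```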
buckets[bi]: - windows.append(buckets[bi].pop()) - if buckets[bi]: - new_active.append(bi) - active = new_active - return windows - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - if self._cfg is None: - self._init_pipeline(global_tokens, seq_len, grad_accum_steps) - _, _, num_seqs, gns = self._cfg - gw = self._sample_global_windows() - local_w = gw[self.rank::self.world_size] - x = torch.empty((num_seqs, seq_len), dtype=torch.int64) - y = torch.empty((num_seqs, seq_len), dtype=torch.int64) - for slot, (si, pos) in enumerate(local_w): - mm = _get_shard_memmap(self.files[si]) - window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) - x[slot] = window[:-1] - y[slot] = window[1:] - self._batches_built += 1 - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# --- Transformer modules --- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - self.rope_dims = rope_dims if rope_dims > 0 else dim - inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) - 
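`next_batch` turns each `seq_len + 1` token window into an input/target pair shifted by one position. The construction in isolation:

```python
def make_xy(window):
    """Split a (seq_len + 1)-token window into next-token-prediction pairs."""
    return window[:-1], window[1:]
```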
self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: - if rope_dims > 0 and rope_dims < x.size(-1): - x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] - half = rope_dims // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rope, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - 
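`apply_rotary_emb` above rotates each `(x1, x2)` feature pair by a position-dependent angle, and the `rope_dims` branch leaves the remaining features untouched (partial RoPE). The per-pair operation is a plain 2-D rotation, so it preserves the pair's norm; a scalar sketch using the same sign convention as `x1 * (-sin) + x2 * cos` above:

```python
import math

def rotate_pair(x1, x2, theta):
    """One RoPE feature pair: a 2-D rotation by angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return x1 * c + x2 * s, -x1 * s + x2 * c
```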
self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - # No CastedLinear -- weights come from banks - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = 0 # set by GPT.__init__ for partial RoPE - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None) -> Tensor: - if getattr(self, '_save_gptq', False): - self._gptq_qkv_in = x.detach() - bsz, seqlen, dim = x.shape - q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = F.linear(x, v_w.to(x.dtype)) - if v_embed is not None: - v = v + v_embed - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin, self.rope_dims) - k = apply_rotary_emb(k, cos, sin, self.rope_dims) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - if getattr(self, '_save_gptq', False): - 
self._gptq_o_in = y.detach() - return F.linear(y, out_w.to(x.dtype)) - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class ValueEmbedding(nn.Module): - """Reinject token identity into attention values at specific layers. 
- Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" - def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): - super().__init__() - self.embed = nn.Embedding(vocab_size, ve_dim) - nn.init.normal_(self.embed.weight, std=0.01) - self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(token_ids) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int, neg_slope: float = 0.5): - super().__init__() - self.neg_slope = neg_slope - # No CastedLinear -- weights come from banks - def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: - if getattr(self, '_save_gptq', False): - self._gptq_up_in = x.detach() - x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=self.neg_slope) - x2 = x.square() - if getattr(self, '_save_gptq', False): - self._gptq_down_in = x2.detach() - return F.linear(x2, down_w.to(x.dtype)) - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - layer_idx: int = 0, - ln_scale: bool = False, - neg_slope: float = 0.5, - ): - super().__init__() - self.layer_idx = layer_idx - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult, neg_slope=neg_slope) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale 
else 1.0 - def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed) - x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out - mlp_out = self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) - return x_out + mlp_out - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - ve_enabled: bool = False, - ve_dim: int = 128, - ve_layers: str = "9,10", - neg_slope: float = 0.5, - ): - super().__init__() - self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - # Parameter banks: 
contiguous 3D tensors for batched optimizer - head_dim = model_dim // num_heads - kv_dim = num_kv_heads * head_dim - mlp_dim = int(mlp_mult * model_dim) - self.num_layers = num_layers - self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) - self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) - self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) - self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - layer_idx=i, - ln_scale=ln_scale, - neg_slope=neg_slope, - ) - for i in range(num_layers) - ] - ) - if rope_dims > 0: - head_dim = model_dim // num_heads - for block in self.blocks: - block.attn.rope_dims = rope_dims - block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] - kv_dim_ve = self._ve_target_dim - if self.ve_layer_indices: - self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) - self.ve_layer_scales = nn.ParameterList( - [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] - ) - else: - self.ve_shared = None - self.ve_layer_scales = nn.ParameterList() - self.value_embeds = nn.ModuleList() # keep empty for compat - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - n = self.num_layers - proj_scale = 1.0 / math.sqrt(2 * n) - # Init banks: orthogonal, 
with proj layers scaled down and out/down zero-init - for i in range(n): - nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q - nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) - nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K - nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V - nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up - nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) - # Scale proj layers (out_proj and mlp_down are "proj" layers) - self.qo_bank.data[n + i].mul_(proj_scale) - self.mlp_down_bank.data[i].mul_(proj_scale) - # Init remaining nn.Linear modules (bigram proj, lm_head) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: - """Get value embedding for a specific layer using shared table + per-layer scale.""" - if self.ve_shared is None or layer_idx not in self.ve_layer_indices: - return None - if ve_cache is not None and 've' not in ve_cache: - ve_cache['ve'] = self.ve_shared(input_ids) - ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) - ve_idx = self.ve_layer_indices.index(layer_idx) - return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], 
self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve) - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve) - x = 
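Both `forward` and `forward_logits` pass logits through `logit_softcap * tanh(logits / logit_softcap)`, which is near-identity for small logits and smoothly saturates toward ±cap. In isolation:

```python
import math

def softcap(logit, cap):
    """Smoothly bound a logit to (-cap, cap); ~identity when |logit| << cap."""
    return cap * math.tanh(logit / cap)
```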
self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - -# --- Sliding window evaluation --- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = 
compiled_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -def eval_val_sliding_ttt( - args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, - device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - stride: int, batch_seqs: int = 32, log0=print, -) -> tuple[float, float]: - """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, - then train on it. 
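The returned metric converts the mean token NLL (in nats) to bits per token and rescales by tokens-per-byte, giving bits per byte of the underlying text; the leading-space term above additionally credits tokens that absorb the space before a word. The conversion by itself:

```python
import math

def bits_per_byte(mean_nll_nats, n_tokens, n_bytes):
    """bpb = (nats -> bits) * (tokens per byte)."""
    return (mean_nll_nats / math.log(2.0)) * (n_tokens / n_bytes)
```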
Every token scored BEFORE any update that could use it.""" - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - ttt_chunk = args.ttt_chunk_tokens - - # Pre-compute all window starts - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - - # Assign each window to a chunk based on the first token it scores - num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk - chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] - for ws in window_starts: - end = min(ws + seq_len, total_tokens) - wlen = end - ws - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_start = ws + s - ci = min(scored_start // ttt_chunk, num_chunks - 1) - chunk_windows[ci].append(ws) - - log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " - f"total_windows={len(window_starts)} stride={stride} " - f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " - f"freeze_blocks={args.ttt_freeze_blocks}") - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Freeze first N blocks - frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) - ttt_params = [] - for name, p in base_model.named_parameters(): - freeze = False - for bi in frozen_block_ids: - if f"blocks.{bi}." 
in name: - freeze = True - break - if freeze: - p.requires_grad_(False) - else: - p.requires_grad_(True) - ttt_params.append(p) - - log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " - f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") - - optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) - t0 = time.perf_counter() - - for ci in range(num_chunks): - windows = chunk_windows[ci] - if not windows: - continue - chunk_start = ci * ttt_chunk - chunk_end = min((ci + 1) * ttt_chunk, total_tokens) - - # --- Phase 1: SCORE this chunk's windows (inference_mode) --- - my_s = (len(windows) * rank) // world_size - my_e = (len(windows) * (rank + 1)) // world_size - my_windows = windows[my_s:my_e] - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk_tok[:-1] - y_batch[i, :wlen] = chunk_tok[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - # --- Phase 2: TRAIN on this chunk (already scored = legal) --- - is_last_chunk = (ci == num_chunks - 1) - if not is_last_chunk and args.ttt_epochs > 0: - base_model.train() - chunk_seqs = (chunk_end - chunk_start) // seq_len - if chunk_seqs > 0: - cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) - for pg in optimizer.param_groups: - pg['lr'] = cos_lr - my_seq_s = (chunk_seqs * rank) // world_size - my_seq_e = (chunk_seqs * (rank + 1)) // world_size - my_chunk_seqs = my_seq_e - my_seq_s - for _ep in range(args.ttt_epochs): - for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): - be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) - actual_bs = my_seq_s + bs - start_tok = chunk_start + actual_bs * seq_len - end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 - if end_tok > val_tokens.numel(): - continue - local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - optimizer.zero_grad(set_to_none=True) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = base_model(x, y) - loss.backward() - if world_size > 1: - for p in ttt_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) - optimizer.step() - - if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): - elapsed = time.perf_counter() - t0 - rl = loss_sum.item() / max(token_count.item(), 1) - rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 - log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = 
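The legality argument for this TTT loop is purely an ordering one: chunk `ci` is always scored before any gradient step that touches it, and the final chunk is scored but never trained on. A toy schedule making that invariant explicit:

```python
def score_first_schedule(num_chunks):
    """Event order of the score-then-train loop above (simplified)."""
    events = []
    for ci in range(num_chunks):
        events.append(("score", ci))
        if ci != num_chunks - 1:  # nothing left to score after the last chunk
            events.append(("train", ci))
    return events
```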
(loss_sum / token_count).item() - val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) - - for p in base_model.parameters(): - p.requires_grad_(True) - base_model.eval() - - log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " - f"elapsed={time.perf_counter() - t0:.1f}s") - return val_loss, val_bpb - - -# --- GPTQ-lite int6 quantization --- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" -def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale - -def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: - """Convert 3D bank tensors into individual 2D tensors with standard names.""" - out: dict[str, Tensor] = {} - n = num_layers - for name, tensor in sd.items(): - if name == "qo_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] - out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] - elif name == "kv_bank": - for i in range(n): - 
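`quantize_int6_per_row` searches several clip percentiles; stripped of that search, the core is symmetric per-row quantization with 6-bit signed codes in [-31, 31]. A one-row sketch that keeps the original's `clamp_min(1.0 / clip_range)` floor so all-zero rows still get a sane scale:

```python
def quantize_row_int6(row, clip_range=31):
    """Absmax per-row int6: codes in [-clip_range, clip_range], one fp scale."""
    amax = max(abs(v) for v in row)
    scale = max(amax / clip_range, 1.0 / clip_range)
    q = [max(-clip_range, min(clip_range, round(v / scale))) for v in row]
    return q, scale
```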
out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] - out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] - elif name == "mlp_up_bank": - for i in range(n): - out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] - elif name == "mlp_down_bank": - for i in range(n): - out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] - else: - out[name] = tensor - return out - -def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - """Convert individual 2D tensors back into 3D bank tensors.""" - out: dict[str, Tensor] = {} - n = num_layers - # Reconstruct banks from individual weight keys - qo_slices = [None] * (2 * n) - kv_slices = [None] * (2 * n) - up_slices = [None] * n - down_slices = [None] * n - consumed = set() - for i in range(n): - qk = f"blocks.{i}.attn.c_q.weight" - if qk in sd: - qo_slices[i] = sd[qk] - consumed.add(qk) - ok = f"blocks.{i}.attn.proj.weight" - if ok in sd: - qo_slices[n + i] = sd[ok] - consumed.add(ok) - kk = f"blocks.{i}.attn.c_k.weight" - if kk in sd: - kv_slices[i] = sd[kk] - consumed.add(kk) - vk = f"blocks.{i}.attn.c_v.weight" - if vk in sd: - kv_slices[n + i] = sd[vk] - consumed.add(vk) - fk = f"blocks.{i}.mlp.fc.weight" - if fk in sd: - up_slices[i] = sd[fk] - consumed.add(fk) - dk = f"blocks.{i}.mlp.proj.weight" - if dk in sd: - down_slices[i] = sd[dk] - consumed.add(dk) - out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) - out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) - out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) - out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) - for name, tensor in sd.items(): - if name not in consumed: - out[name] = tensor - return out - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], clip_range: int = 31, - hessians: dict[str, Tensor] | None = None): - num_layers_total = max( - (int(k.split(".")[1]) for 
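`_unbank_state_dict` / `_rebank_state_dict` exist so the 3-D parameter banks can be quantized under per-layer 2-D names and then reassembled. A dictionary-level sketch of the round trip for the MLP banks only (lists stand in for tensors; names mirror the keys above):

```python
def unbank(banks, n):
    """Expose each bank slice under a per-layer weight name."""
    sd = {}
    for i in range(n):
        sd[f"blocks.{i}.mlp.fc.weight"] = banks["mlp_up_bank"][i]
        sd[f"blocks.{i}.mlp.proj.weight"] = banks["mlp_down_bank"][i]
    return sd

def rebank(sd, n):
    """Reassemble the banks from per-layer names (inverse of unbank)."""
    return {
        "mlp_up_bank": [sd[f"blocks.{i}.mlp.fc.weight"] for i in range(n)],
        "mlp_down_bank": [sd[f"blocks.{i}.mlp.proj.weight"] for i in range(n)],
    }
```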
k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - gptq_count, naive_count = 0, 0 - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - H = hessians.get(name) if hessians else None - if H is not None and t.ndim == 2: - q, s = gptq_quantize_weight(t, H.cpu(), clip_range=clip_range) - gptq_count += 1 - else: - q, s = quantize_int6_per_row(t, clip_range=clip_range) - naive_count += 1 - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - if hessians: - print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers", flush=True) - return result, meta -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * 
float(s.item())).to(orig_dtype) - return out - -# --- Full Hessian GPTQ --- - -def gptq_quantize_weight(W: Tensor, H: Tensor, clip_range: int = 31, - block_size: int = 128, percdamp: float = 0.01) -> tuple[Tensor, Tensor]: - """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023).""" - W_orig = W.float().clone() - rows, cols = W_orig.shape - H = H.float().clone() - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - damp = percdamp * H.diag().mean() - H.diagonal().add_(damp) - perm = torch.argsort(H.diag(), descending=True) - invperm = torch.argsort(perm) - W_perm = W_orig[:, perm].clone() - W_perm[:, dead[perm]] = 0 - H = H[perm][:, perm] - try: - Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) - Hinv = torch.linalg.cholesky(Hinv, upper=True) - except torch.linalg.LinAlgError: - return quantize_int6_per_row(W_orig, clip_range) - best_q, best_scale, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(W_orig.abs(), pct, dim=1) - else: - row_clip = W_orig.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - sf = s.float() - Q = torch.zeros(rows, cols, dtype=torch.int8) - W_work = W_perm.clone() - for i1 in range(0, cols, block_size): - i2 = min(i1 + block_size, cols) - W_block = W_work[:, i1:i2].clone() - Hinv_block = Hinv[i1:i2, i1:i2] - Err = torch.zeros(rows, i2 - i1) - for j in range(i2 - i1): - w_col = W_block[:, j] - d = Hinv_block[j, j] - q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) - Q[:, i1 + j] = q_col.to(torch.int8) - err = (w_col - q_col.float() * sf) / d - Err[:, j] = err - W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) - if i2 < cols: - W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] - recon = Q.float() * sf[:, None] - mse = (W_perm - recon).pow(2).mean().item() - if mse < best_err: - best_q, best_scale, best_err = Q, s, mse - best_q = best_q[:, invperm] - return 
best_q, best_scale - -def _init_hessians(nl: int, dim: int, mlp_dim: int, device: torch.device) -> dict[str, Tensor]: - h: dict[str, Tensor] = {} - for i in range(nl): - for k in ['c_q', 'c_k', 'c_v']: - h[f'blocks.{i}.attn.{k}.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.attn.proj.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.mlp.fc.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.mlp.proj.weight'] = torch.zeros(mlp_dim, mlp_dim, dtype=torch.float32, device=device) - return h - -def _accum_hessians(hessians: dict[str, Tensor], blocks: nn.ModuleList, dim: int, mlp_dim: int) -> None: - for i, block in enumerate(blocks): - qkv_in = block.attn._gptq_qkv_in.float().reshape(-1, dim) - h_qkv = qkv_in.t() @ qkv_in - hessians[f'blocks.{i}.attn.c_q.weight'] += h_qkv - hessians[f'blocks.{i}.attn.c_k.weight'] += h_qkv - hessians[f'blocks.{i}.attn.c_v.weight'] += h_qkv - o_in = block.attn._gptq_o_in.float().reshape(-1, dim) - hessians[f'blocks.{i}.attn.proj.weight'] += o_in.t() @ o_in - up_in = block.mlp._gptq_up_in.float().reshape(-1, dim) - hessians[f'blocks.{i}.mlp.fc.weight'] += up_in.t() @ up_in - down_in = block.mlp._gptq_down_in.float().reshape(-1, mlp_dim) - hessians[f'blocks.{i}.mlp.proj.weight'] += down_in.t() @ down_in - -def _finalize_hessians(hessians: dict[str, Tensor], num_batches: int) -> None: - for name in hessians: - hessians[name] = hessians[name].cpu() / num_batches - damp = 0.01 * torch.diag(hessians[name]).mean().clamp_min(1e-6) - hessians[name] += damp * torch.eye(hessians[name].shape[0]) - -def gptq_collect_hessians(base_model: nn.Module, train_loader, device: torch.device, - num_batches: int, batch_tokens: int, seq_len: int, - grad_accum_steps: int) -> dict[str, Tensor]: - """Collect Hessians H = X^T X from training data.""" - nl = base_model.num_layers - dim = base_model.tok_emb.weight.shape[1] - mlp_dim = 
base_model.mlp_up_bank.shape[1] - hessians = _init_hessians(nl, dim, mlp_dim, device) - for block in base_model.blocks: - block.attn._save_gptq = True - block.mlp._save_gptq = True - base_model.eval() - with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): - for _ in range(num_batches): - x, y = train_loader.next_batch(batch_tokens, seq_len, grad_accum_steps) - base_model(x, y) - _accum_hessians(hessians, base_model.blocks, dim, mlp_dim) - for block in base_model.blocks: - block.attn._save_gptq = False - block.mlp._save_gptq = False - _finalize_hessians(hessians, num_batches) - base_model.train() - return hessians - -# --- Training --- - -def main() -> None: - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - logfile = None - if master_process: - 
os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - (base_bytes_lut, has_leading_space_lut, is_boundary_token_lut), tokenizer_metadata = load_tokenizer_luts( - args.tokenizer_path, args.tokenizer_meta_path, args.vocab_size, device, - validate_meta=args.tokenizer_meta_validate, - ) - log0(f"tokenizer: kind={tokenizer_metadata.get('tokenizer_kind', 'unknown')} vocab={tokenizer_metadata.get('vocab_size', '?')}") - if tokenizer_metadata.get('tokenizer_kind') == 'sentencepiece': - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - base_model = GPT( - 
vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ve_enabled=args.ve_enabled, - ve_dim=args.ve_dim, - ve_layers=args.ve_layers, - neg_slope=args.negative_slope, - ).to(device).bfloat16() - # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward - base_model.qo_bank.data = base_model.qo_bank.data.float() - base_model.kv_bank.data = base_model.kv_bank.data.float() - base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() - base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, - # and non-bank grads are manually all-reduced before Adam steps. 
- compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model = compiled_model - - # Optimizer split: - # - 4 parameter banks -> Muon (batched Newton-Schulz) - # - token embedding -> Adam - # - scalars/control tensors -> Adam - # - bigram proj, VE proj -> Adam (small matrix params not worth banking) - matrix_params = [ - base_model.qo_bank, base_model.kv_bank, - base_model.mlp_up_bank, base_model.mlp_down_bank, - ] - block_named_params = list(base_model.blocks.named_parameters()) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - scalar_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - scalar_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in 
optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - # Non-bank params that need manual all-reduce (replicated across GPUs) - replicated_params = list(optimizer_tok.param_groups[0]["params"]) - for pg in optimizer_tok.param_groups[1:]: - replicated_params.extend(pg["params"]) - replicated_params.extend(scalar_params) - - optimizer_head = None - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - replicated_params.append(base_model.lm_head.weight) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if optimizer_head is not None: - optimizers.append(optimizer_head) - log0(f"model_params:{sum(p.numel() for p in base_model.parameters())}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - if args.use_gptq and max_wallclock_ms is not None: - max_wallclock_ms -= args.gptq_reserve_ms - log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - # All-reduce all grads for warmup (simple, not optimized) - if distributed: - for p in base_model.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 
1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = 0.997 - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * 
grad_scale).backward() - train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - # === 3-phase overlapped optimizer step === - # Phase 1: Launch async reduce-scatter for banks (biggest first) - optimizer_muon.launch_reduce_scatters() - # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) - if distributed: - for p in replicated_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - optimizer_tok.step() - optimizer_scalar.step() - if optimizer_head is not None: - optimizer_head.step() - # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) - optimizer_muon.step() - zero_grad_all() - # EMA update - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - 
f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - # Apply EMA weights - log0("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - torch.cuda.synchronize() - t_diag = time.perf_counter() - diag_val_loss, diag_val_bpb = eval_val( - args, compiled_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" - ) - export_sd = base_model.state_dict() - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - # Unbank 3D tensors into individual 2D tensors for quantization - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) - # GPTQ calibration: collect Hessians from training data - gptq_hessians = None - if args.use_gptq: - t_gptq = time.perf_counter() - log0(f"gptq:calibrating with {args.gptq_calib_samples} 
batches (training data)...") - calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - gptq_hessians = gptq_collect_hessians( - base_model, calib_loader, device, num_batches=args.gptq_calib_samples, - batch_tokens=args.train_batch_tokens, seq_len=args.train_seq_len, - grad_accum_steps=grad_accum_steps) - del calib_loader - gptq_elapsed = time.perf_counter() - t_gptq - log0(f"gptq:calibrated {len(gptq_hessians)} layers in {gptq_elapsed:.1f}s") - torch.cuda.empty_cache() - quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}, clip_range=args.quant_clip_range, hessians=gptq_hessians) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = lzma.compress(quant_raw, preset=6) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") - log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(lzma.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) - # Re-bank the dequantized tensors - deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - 
xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, ln_scale=args.ln_scale, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - neg_slope=args.negative_slope, - ).to(device).bfloat16() - eval_model.qo_bank.data = eval_model.qo_bank.data.float() - eval_model.kv_bank.data = eval_model.kv_bank.data.float() - eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() - eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - 
log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - # Legal score-first TTT (PR #461 recipe) - if args.ttt_enabled: - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_loss, ttt_bpb = eval_val_sliding_ttt( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, log0=log0, - ) - torch.cuda.synchronize() - log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") - log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Mar 3 2026, 12:15:18) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Tue Mar 31 14:43:28 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A 
| Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 29C P0 124W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3F:00.0 Off | 0 | -| N/A 28C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:48:00.0 Off | 0 | -| N/A 33C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | -| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | -| N/A 28C P0 112W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | -| N/A 28C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | -| N/A 33C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | 
-+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 1930 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 1 N/A N/A 1931 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 2 N/A N/A 1932 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 3 N/A N/A 1933 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 4 N/A N/A 1934 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 5 N/A N/A 1935 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 6 N/A N/A 1936 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 7 N/A N/A 1937 C ...ameter-golf/.venv/bin/python3 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -tokenizer: kind=tokenmonster vocab=998 -train_loader:dataset:fineweb10B_scylla train_shards:194 -val_loader:shards pattern=./data/datasets/fineweb10B_scylla/fineweb_val_*.bin tokens:62363648 -model_params:27022172 -XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -gptq:reserving 9000ms from training budget, effective=591000ms -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 
-warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.8893 val_bpb:3.3906 train_time:0ms step_avg:0.01ms -step:1/9000 train_loss:6.8885 train_time:133ms step_avg:133.24ms -step:2/9000 train_loss:9.2540 train_time:168ms step_avg:84.25ms -step:3/9000 train_loss:9.1371 train_time:253ms step_avg:84.35ms -step:4/9000 train_loss:8.5749 train_time:338ms step_avg:84.61ms -step:5/9000 train_loss:8.1062 train_time:424ms step_avg:84.76ms -step:6/9000 train_loss:7.5450 train_time:509ms step_avg:84.79ms -step:7/9000 train_loss:7.0905 train_time:594ms step_avg:84.81ms -step:8/9000 train_loss:6.7904 train_time:679ms step_avg:84.83ms -step:9/9000 train_loss:6.6502 train_time:764ms step_avg:84.94ms -step:10/9000 train_loss:6.4416 train_time:849ms step_avg:84.94ms -step:500/9000 train_loss:2.3187 train_time:43818ms step_avg:87.64ms -step:1000/9000 train_loss:2.2047 train_time:87672ms step_avg:87.67ms -step:1500/9000 train_loss:2.1824 train_time:131516ms step_avg:87.68ms -step:2000/9000 train_loss:2.1525 train_time:175431ms step_avg:87.72ms -step:2500/9000 train_loss:2.2113 train_time:219385ms step_avg:87.75ms -step:3000/9000 train_loss:2.1066 train_time:263337ms step_avg:87.78ms -step:3500/9000 train_loss:2.0862 train_time:307270ms step_avg:87.79ms -step:4000/9000 train_loss:2.0368 train_time:351208ms step_avg:87.80ms -step:4000/9000 val_loss:2.0789 val_bpb:1.0232 train_time:351264ms step_avg:87.82ms -step:4500/9000 train_loss:2.0131 train_time:395168ms step_avg:87.82ms -step:5000/9000 train_loss:1.9703 train_time:439082ms step_avg:87.82ms -step:5500/9000 train_loss:1.9399 train_time:483044ms step_avg:87.83ms -step:6000/9000 train_loss:1.9922 train_time:527029ms step_avg:87.84ms -swa:start step:6050 -step:6500/9000 train_loss:1.9783 train_time:571716ms step_avg:87.96ms -step:6716/9000 val_loss:1.9638 val_bpb:0.9665 train_time:591133ms step_avg:88.02ms -stopping_early: wallclock_cap 
train_time:591133ms step:6716/9000 -peak memory allocated: 23035 MiB reserved: 24046 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9622 val_bpb:0.9657 eval_time:2112ms -Serialized model: 106201974 bytes -Code size: 101929 bytes -gptq:calibrating with 64 batches (training data)... -gptq:calibrated 66 layers in 6.7s -Serialized model int6+lzma: 15859076 bytes -Total submission size int6+lzma: 15961005 bytes -final_int6_roundtrip val_loss:1.9683 val_bpb:0.9687 eval_time:16727ms -final_int6_roundtrip_exact val_loss:1.96829674 val_bpb:0.96871918 -final_int6_sliding_window val_loss:1.9285 val_bpb:0.9491 stride:64 eval_time:92334ms -final_int6_sliding_window_exact val_loss:1.92847248 val_bpb:0.94911064 -final_int8_zlib_roundtrip_exact val_loss:1.92847248 val_bpb:0.94911064 -ttt_sliding:start chunks=1904 chunk_tokens=32768 total_windows=974432 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 -ttt_sliding:params unfrozen=27022172 frozen=0 - ttt_chunk [1/1904] bpb=0.914901 time=0.5s - ttt_chunk [11/1904] bpb=0.913003 time=2.8s - ttt_chunk [21/1904] bpb=0.937304 time=5.0s - ttt_chunk [31/1904] bpb=0.941354 time=7.3s - ttt_chunk [41/1904] bpb=0.968061 time=9.5s - ttt_chunk [51/1904] bpb=0.967601 time=11.8s - ttt_chunk [61/1904] bpb=0.968924 time=14.1s - ttt_chunk [71/1904] bpb=0.962667 time=16.3s - ttt_chunk [81/1904] bpb=0.974603 time=18.6s - ttt_chunk [91/1904] bpb=0.967818 time=20.9s - ttt_chunk [101/1904] bpb=0.965223 time=23.3s - ttt_chunk [111/1904] bpb=0.960300 time=25.6s - ttt_chunk [121/1904] bpb=0.961390 time=28.0s - ttt_chunk [131/1904] bpb=0.961213 time=30.4s - ttt_chunk [141/1904] bpb=0.957305 time=32.7s - ttt_chunk [151/1904] bpb=0.954450 time=35.1s - ttt_chunk [161/1904] bpb=0.953940 time=37.3s - ttt_chunk [171/1904] bpb=0.951811 time=39.6s - ttt_chunk [181/1904] bpb=0.950999 time=41.9s - ttt_chunk [191/1904] bpb=0.950985 time=44.3s - ttt_chunk [201/1904] bpb=0.952431 time=46.7s - ttt_chunk [211/1904] bpb=0.950764 time=49.0s - ttt_chunk 
[221/1904] bpb=0.947317 time=51.4s - ttt_chunk [231/1904] bpb=0.945341 time=53.8s - ttt_chunk [241/1904] bpb=0.947721 time=56.1s - ttt_chunk [251/1904] bpb=0.946732 time=58.5s - ttt_chunk [261/1904] bpb=0.944192 time=60.9s - ttt_chunk [271/1904] bpb=0.944057 time=63.2s - ttt_chunk [281/1904] bpb=0.943612 time=65.6s - ttt_chunk [291/1904] bpb=0.942713 time=67.9s - ttt_chunk [301/1904] bpb=0.943475 time=70.3s - ttt_chunk [311/1904] bpb=0.944422 time=72.6s - ttt_chunk [321/1904] bpb=0.944382 time=75.0s - ttt_chunk [331/1904] bpb=0.944306 time=77.4s - ttt_chunk [341/1904] bpb=0.943192 time=79.7s - ttt_chunk [351/1904] bpb=0.942651 time=82.1s - ttt_chunk [361/1904] bpb=0.944115 time=84.5s - ttt_chunk [371/1904] bpb=0.943507 time=86.8s - ttt_chunk [381/1904] bpb=0.943059 time=89.2s - ttt_chunk [391/1904] bpb=0.942307 time=91.6s - ttt_chunk [401/1904] bpb=0.942080 time=93.9s - ttt_chunk [411/1904] bpb=0.941052 time=96.3s - ttt_chunk [421/1904] bpb=0.939784 time=98.7s - ttt_chunk [431/1904] bpb=0.938400 time=101.0s - ttt_chunk [441/1904] bpb=0.937732 time=103.4s - ttt_chunk [451/1904] bpb=0.935945 time=105.8s - ttt_chunk [461/1904] bpb=0.935694 time=108.2s - ttt_chunk [471/1904] bpb=0.934666 time=110.5s - ttt_chunk [481/1904] bpb=0.935068 time=112.9s - ttt_chunk [491/1904] bpb=0.935595 time=115.2s - ttt_chunk [501/1904] bpb=0.936116 time=117.6s - ttt_chunk [511/1904] bpb=0.936868 time=120.0s - ttt_chunk [521/1904] bpb=0.939505 time=122.3s - ttt_chunk [531/1904] bpb=0.939533 time=124.7s - ttt_chunk [541/1904] bpb=0.940639 time=127.1s - ttt_chunk [551/1904] bpb=0.940191 time=129.4s - ttt_chunk [561/1904] bpb=0.940047 time=131.8s - ttt_chunk [571/1904] bpb=0.940319 time=134.2s - ttt_chunk [581/1904] bpb=0.941310 time=136.5s - ttt_chunk [591/1904] bpb=0.941933 time=138.9s - ttt_chunk [601/1904] bpb=0.941911 time=141.3s - ttt_chunk [611/1904] bpb=0.941772 time=143.6s - ttt_chunk [621/1904] bpb=0.943185 time=146.0s - ttt_chunk [631/1904] bpb=0.943644 time=148.4s - ttt_chunk 
[641/1904] bpb=0.944237 time=150.7s - ttt_chunk [651/1904] bpb=0.943977 time=153.1s - ttt_chunk [661/1904] bpb=0.944482 time=155.4s - ttt_chunk [671/1904] bpb=0.945080 time=157.8s - ttt_chunk [681/1904] bpb=0.945138 time=160.2s - ttt_chunk [691/1904] bpb=0.945189 time=162.5s - ttt_chunk [701/1904] bpb=0.945827 time=164.9s - ttt_chunk [711/1904] bpb=0.946313 time=167.3s - ttt_chunk [721/1904] bpb=0.945798 time=169.5s - ttt_chunk [731/1904] bpb=0.945671 time=171.8s - ttt_chunk [741/1904] bpb=0.945600 time=174.1s - ttt_chunk [751/1904] bpb=0.945010 time=176.3s - ttt_chunk [761/1904] bpb=0.944838 time=178.6s - ttt_chunk [771/1904] bpb=0.944373 time=180.9s - ttt_chunk [781/1904] bpb=0.944208 time=183.3s - ttt_chunk [791/1904] bpb=0.943991 time=185.7s - ttt_chunk [801/1904] bpb=0.944312 time=188.0s - ttt_chunk [811/1904] bpb=0.945292 time=190.4s - ttt_chunk [821/1904] bpb=0.944678 time=192.8s - ttt_chunk [831/1904] bpb=0.945320 time=195.1s - ttt_chunk [841/1904] bpb=0.945218 time=197.5s - ttt_chunk [851/1904] bpb=0.945569 time=199.9s - ttt_chunk [861/1904] bpb=0.946017 time=202.3s - ttt_chunk [871/1904] bpb=0.946303 time=204.6s - ttt_chunk [881/1904] bpb=0.946467 time=207.0s - ttt_chunk [891/1904] bpb=0.947352 time=209.2s - ttt_chunk [901/1904] bpb=0.947989 time=211.6s - ttt_chunk [911/1904] bpb=0.948574 time=213.9s - ttt_chunk [921/1904] bpb=0.949399 time=216.3s - ttt_chunk [931/1904] bpb=0.949907 time=218.7s - ttt_chunk [941/1904] bpb=0.949558 time=221.0s - ttt_chunk [951/1904] bpb=0.949777 time=223.4s - ttt_chunk [961/1904] bpb=0.949834 time=225.7s - ttt_chunk [971/1904] bpb=0.950314 time=228.1s - ttt_chunk [981/1904] bpb=0.950653 time=230.4s - ttt_chunk [991/1904] bpb=0.950442 time=232.7s - ttt_chunk [1001/1904] bpb=0.950593 time=235.1s - ttt_chunk [1011/1904] bpb=0.951170 time=237.5s - ttt_chunk [1021/1904] bpb=0.951753 time=239.8s - ttt_chunk [1031/1904] bpb=0.951453 time=242.2s - ttt_chunk [1041/1904] bpb=0.951271 time=244.5s - ttt_chunk [1051/1904] bpb=0.950964 
time=246.9s - ttt_chunk [1061/1904] bpb=0.950293 time=249.3s - ttt_chunk [1071/1904] bpb=0.950168 time=251.6s - ttt_chunk [1081/1904] bpb=0.950004 time=254.0s - ttt_chunk [1091/1904] bpb=0.949922 time=256.4s - ttt_chunk [1101/1904] bpb=0.950135 time=258.7s - ttt_chunk [1111/1904] bpb=0.950082 time=261.1s - ttt_chunk [1121/1904] bpb=0.950364 time=263.4s - ttt_chunk [1131/1904] bpb=0.950346 time=265.8s - ttt_chunk [1141/1904] bpb=0.950656 time=268.2s - ttt_chunk [1151/1904] bpb=0.950340 time=270.5s - ttt_chunk [1161/1904] bpb=0.950405 time=272.9s - ttt_chunk [1171/1904] bpb=0.950157 time=275.2s - ttt_chunk [1181/1904] bpb=0.950237 time=277.6s - ttt_chunk [1191/1904] bpb=0.949822 time=280.0s - ttt_chunk [1201/1904] bpb=0.950124 time=282.3s - ttt_chunk [1211/1904] bpb=0.950281 time=284.7s - ttt_chunk [1221/1904] bpb=0.950330 time=287.0s - ttt_chunk [1231/1904] bpb=0.951030 time=289.4s - ttt_chunk [1241/1904] bpb=0.951141 time=291.8s - ttt_chunk [1251/1904] bpb=0.951689 time=294.1s - ttt_chunk [1261/1904] bpb=0.951526 time=296.5s - ttt_chunk [1271/1904] bpb=0.951625 time=298.9s - ttt_chunk [1281/1904] bpb=0.951833 time=301.2s - ttt_chunk [1291/1904] bpb=0.952138 time=303.6s - ttt_chunk [1301/1904] bpb=0.952691 time=306.0s - ttt_chunk [1311/1904] bpb=0.952943 time=308.3s - ttt_chunk [1321/1904] bpb=0.953451 time=310.7s - ttt_chunk [1331/1904] bpb=0.953436 time=313.0s - ttt_chunk [1341/1904] bpb=0.953740 time=315.4s - ttt_chunk [1351/1904] bpb=0.954124 time=317.8s - ttt_chunk [1361/1904] bpb=0.954480 time=320.1s - ttt_chunk [1371/1904] bpb=0.954304 time=322.5s - ttt_chunk [1381/1904] bpb=0.953771 time=324.9s - ttt_chunk [1391/1904] bpb=0.953956 time=327.2s - ttt_chunk [1401/1904] bpb=0.953548 time=329.6s - ttt_chunk [1411/1904] bpb=0.953520 time=331.9s - ttt_chunk [1421/1904] bpb=0.953437 time=334.3s - ttt_chunk [1431/1904] bpb=0.953470 time=336.7s - ttt_chunk [1441/1904] bpb=0.953492 time=339.0s - ttt_chunk [1451/1904] bpb=0.953543 time=341.4s - ttt_chunk [1461/1904] 
bpb=0.953783 time=343.8s - ttt_chunk [1471/1904] bpb=0.953578 time=346.1s - ttt_chunk [1481/1904] bpb=0.953328 time=348.5s - ttt_chunk [1491/1904] bpb=0.952615 time=350.8s - ttt_chunk [1501/1904] bpb=0.953147 time=353.2s - ttt_chunk [1511/1904] bpb=0.953082 time=355.6s - ttt_chunk [1521/1904] bpb=0.952924 time=357.9s - ttt_chunk [1531/1904] bpb=0.952638 time=360.3s - ttt_chunk [1541/1904] bpb=0.952957 time=362.6s - ttt_chunk [1551/1904] bpb=0.953036 time=365.0s - ttt_chunk [1561/1904] bpb=0.953133 time=367.4s - ttt_chunk [1571/1904] bpb=0.953289 time=369.7s - ttt_chunk [1581/1904] bpb=0.953147 time=372.1s - ttt_chunk [1591/1904] bpb=0.952888 time=374.4s - ttt_chunk [1601/1904] bpb=0.953175 time=376.8s - ttt_chunk [1611/1904] bpb=0.953120 time=379.2s - ttt_chunk [1621/1904] bpb=0.953266 time=381.5s - ttt_chunk [1631/1904] bpb=0.953248 time=383.9s - ttt_chunk [1641/1904] bpb=0.953153 time=386.2s - ttt_chunk [1651/1904] bpb=0.953304 time=388.6s - ttt_chunk [1661/1904] bpb=0.953450 time=391.0s - ttt_chunk [1671/1904] bpb=0.953031 time=393.3s - ttt_chunk [1681/1904] bpb=0.953203 time=395.7s - ttt_chunk [1691/1904] bpb=0.953012 time=398.1s - ttt_chunk [1701/1904] bpb=0.953064 time=400.4s - ttt_chunk [1711/1904] bpb=0.953522 time=402.8s - ttt_chunk [1721/1904] bpb=0.953598 time=405.1s - ttt_chunk [1731/1904] bpb=0.954146 time=407.5s - ttt_chunk [1741/1904] bpb=0.954059 time=409.9s - ttt_chunk [1751/1904] bpb=0.954169 time=412.2s - ttt_chunk [1761/1904] bpb=0.954057 time=414.6s - ttt_chunk [1771/1904] bpb=0.954031 time=416.9s - ttt_chunk [1781/1904] bpb=0.953853 time=419.3s - ttt_chunk [1791/1904] bpb=0.953654 time=421.7s - ttt_chunk [1801/1904] bpb=0.953564 time=424.0s - ttt_chunk [1811/1904] bpb=0.953581 time=426.4s - ttt_chunk [1821/1904] bpb=0.953341 time=428.8s - ttt_chunk [1831/1904] bpb=0.953460 time=431.1s - ttt_chunk [1841/1904] bpb=0.953381 time=433.5s - ttt_chunk [1851/1904] bpb=0.953179 time=435.8s - ttt_chunk [1861/1904] bpb=0.953061 time=438.2s - ttt_chunk 
[1871/1904] bpb=0.952992 time=440.6s - ttt_chunk [1881/1904] bpb=0.953127 time=442.9s - ttt_chunk [1891/1904] bpb=0.952652 time=445.3s - ttt_chunk [1901/1904] bpb=0.952918 time=447.6s - ttt_chunk [1904/1904] bpb=0.952924 time=448.1s -ttt_sliding:done val_loss=1.928505 val_bpb=0.949127 elapsed=448.2s -legal_ttt val_loss:1.9285 val_bpb:0.9491 eval_time:448555ms -legal_ttt_exact val_loss:1.92850507 val_bpb:0.94912668 diff --git a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed2025.log b/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed2025.log deleted file mode 100644 index 775d27aa2e..0000000000 --- a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed2025.log +++ /dev/null @@ -1,2308 +0,0 @@ -from __future__ import annotations -import copy -import glob -import io -import lzma -import math -import os -import random -import subprocess -import sys -import time -import uuid -from pathlib import Path -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func as flash_attn_3_func -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = 
int(os.environ.get("WARMDOWN_ITERS", 3500)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - 
swa_every = int(os.environ.get("SWA_EVERY", 50)) - muon_wd = float(os.environ.get("MUON_WD", 0.04)) - adam_wd = float(os.environ.get("ADAM_WD", 0.04)) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - rope_dims = int(os.environ.get("ROPE_DIMS", 16)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) - ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) - ve_dim = int(os.environ.get("VE_DIM", 128)) - ve_layers = os.environ.get("VE_LAYERS", "9,10") - ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) - ttt_lr = float(os.environ.get("TTT_LR", 0.002)) - ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) - ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) - ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) - ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) - ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) - ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) - negative_slope = float(os.environ.get("NEGATIVE_SLOPE", 0.5)) - use_gptq = bool(int(os.environ.get("USE_GPTQ", "0"))) - gptq_calib_samples = int(os.environ.get("GPTQ_CALIB_SAMPLES", "64")) - gptq_reserve_ms = float(os.environ.get("GPTQ_RESERVE_MS", "14000")) - quant_clip_range = int(os.environ.get("QUANT_CLIP_RANGE", 31)) - tokenizer_meta_path = os.environ.get("TOKENIZER_META_PATH", "") - tokenizer_meta_validate = bool(int(os.environ.get("TOKENIZER_META_VALIDATE", "0"))) - -# --- Batched Newton-Schulz orthogonalization --- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: - """Batched Newton-Schulz orthogonalization. 
G: (B,M,N) or (M,N).""" - a, b, c = (3.4445, -4.7750, 2.0315) - was_2d = G.ndim == 2 - if was_2d: - G = G.unsqueeze(0) - X = G.bfloat16() - transposed = X.size(-2) > X.size(-1) - if transposed: - X = X.mT - X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) - for _ in range(steps): - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - if transposed: - X = X.mT - if was_2d: - X = X.squeeze(0) - return X - -# --- Parallel Muon optimizer --- - -class Muon(torch.optim.Optimizer): - """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. - - No DDP for bank params. After backward, this optimizer: - 1. Launches async reduce-scatter for all banks (biggest first) - 2. Returns control so Adam can step on small params while RS is in-flight - 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather - 4. Each all-gather overlaps with next bank's NS5 - """ - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - self._built = False - - def _build(self): - self._distributed = dist.is_available() and dist.is_initialized() - self._world_size = dist.get_world_size() if self._distributed else 1 - self._rank = dist.get_rank() if self._distributed else 0 - ws = self._world_size - - self._bank_meta = [] - for group in self.param_groups: - for p in group["params"]: - B = p.shape[0] - padded_B = ((B + ws - 1) // ws) * ws - shard_B = padded_B // ws - tail = p.shape[1:] - dev = p.device - self._bank_meta.append({ - 'p': p, - 'B': B, - 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'full_update': torch.zeros(padded_B, *tail, device=dev, 
dtype=torch.bfloat16), - 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, - }) - # Sort by size descending -- launch biggest reduce-scatters first - self._bank_meta.sort(key=lambda m: -m['p'].numel()) - self._built = True - - def launch_reduce_scatters(self): - """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" - if not self._built: - self._build() - if not self._distributed: - return - self._rs_futures = [] - for m in self._bank_meta: - p = m['p'] - if p.grad is None: - self._rs_futures.append(None) - continue - pg = m['padded_grad'] - pg[:m['B']].copy_(p.grad.bfloat16()) - if pg.shape[0] > m['B']: - pg[m['B']:].zero_() - fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) - self._rs_futures.append(fut) - - @torch.no_grad() - def step(self, closure=None): - """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - if not self._built: - self._build() - - for group in self.param_groups: - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group.get("weight_decay", 0.0) - - prev_ag_handle = None - prev_m = None - - sharded = self._distributed and hasattr(self, '_rs_futures') - - for i, m in enumerate(self._bank_meta): - p = m['p'] - if p.grad is None: - continue - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if sharded and self._rs_futures[i] is not None: - self._rs_futures[i].wait() - g = m['shard'] - buf = m['shard_mom'] - else: - g = p.grad.bfloat16() - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - - buf.mul_(momentum).add_(g) - if 
nesterov: - update = g.add(buf, alpha=momentum) - else: - update = buf - - update = zeropower_via_newtonschulz5(update, steps=backend_steps) - - if sharded: - prev_ag_handle = dist.all_gather_into_tensor( - m['full_update'], update, async_op=True) - prev_m = m - else: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if hasattr(self, '_rs_futures'): - del self._rs_futures - - return loss - -# --- Tokenizer evaluation helpers --- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -TOKENIZER_META_FORMAT_VERSION = 1 -TOKENIZER_META_SUFFIX = ".meta.npz" - - -def _derive_tokenizer_meta_path(tokenizer_path: str) -> Path: - tokenizer = Path(tokenizer_path) - if tokenizer.suffix == ".model": - 
return tokenizer.with_suffix(TOKENIZER_META_SUFFIX) - return tokenizer.with_name(tokenizer.name + TOKENIZER_META_SUFFIX) - - -def build_sentencepiece_luts_np( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return base_bytes_np, has_leading_space_np, is_boundary_token_np - - -def load_tokenizer_meta_luts_np( - meta_path: Path, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, object]]: - def _scalar(value): - arr = np.asarray(value) - if arr.ndim == 0: - return arr.item() - first = arr.reshape(-1)[0] - return first.item() if hasattr(first, "item") else first - - with np.load(meta_path, allow_pickle=False) as data: - format_version = int(_scalar(data["format_version"])) - if format_version != TOKENIZER_META_FORMAT_VERSION: - raise ValueError( - f"Unsupported tokenizer meta format_version={format_version} " - f"expected={TOKENIZER_META_FORMAT_VERSION}" - ) - meta_vocab_size = int(_scalar(data["vocab_size"])) - tokenizer_kind = str(_scalar(data["tokenizer_kind"])) - source_model_name = str(_scalar(data["source_model_name"])) - base_bytes_np = np.asarray(data["base_bytes"], dtype=np.int16) - has_leading_space_np = np.asarray(data["has_leading_space"], dtype=np.bool_) - is_boundary_token_np = 
np.asarray(data["is_boundary_token"], dtype=np.bool_) - table_size = max(meta_vocab_size, vocab_size) - if base_bytes_np.shape[0] < table_size: - padded_base_bytes = np.zeros((table_size,), dtype=np.int16) - padded_has_leading_space = np.zeros((table_size,), dtype=np.bool_) - padded_is_boundary = np.ones((table_size,), dtype=np.bool_) - padded_base_bytes[: base_bytes_np.shape[0]] = base_bytes_np - padded_has_leading_space[: has_leading_space_np.shape[0]] = has_leading_space_np - padded_is_boundary[: is_boundary_token_np.shape[0]] = is_boundary_token_np - base_bytes_np = padded_base_bytes - has_leading_space_np = padded_has_leading_space - is_boundary_token_np = padded_is_boundary - metadata = { - "format_version": format_version, - "tokenizer_kind": tokenizer_kind, - "source_model_name": source_model_name, - "vocab_size": meta_vocab_size, - "meta_path": str(meta_path), - } - return base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata - - -def load_tokenizer_luts( - tokenizer_path: str, - tokenizer_meta_path: str, - vocab_size: int, - device: torch.device, - *, - validate_meta: bool = False, -) -> tuple[tuple[Tensor, Tensor, Tensor], dict[str, object]]: - meta_path = ( - Path(tokenizer_meta_path) if tokenizer_meta_path - else _derive_tokenizer_meta_path(tokenizer_path) - ) - if meta_path.exists(): - base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata = ( - load_tokenizer_meta_luts_np(meta_path, vocab_size) - ) - if validate_meta and str(tokenizer_path).endswith(".model"): - sp = spm.SentencePieceProcessor(model_file=tokenizer_path) - sp_luts = build_sentencepiece_luts_np(sp, vocab_size) - if not ( - np.array_equal(base_bytes_np, sp_luts[0]) - and np.array_equal(has_leading_space_np, sp_luts[1]) - and np.array_equal(is_boundary_token_np, sp_luts[2]) - ): - raise ValueError(f"Tokenizer metadata mismatch for {meta_path}") - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, 
dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ), metadata - if not str(tokenizer_path).endswith(".model"): - raise FileNotFoundError(f"TOKENIZER_META_PATH does not exist: {meta_path}") - sp = spm.SentencePieceProcessor(model_file=tokenizer_path) - return build_sentencepiece_luts(sp, vocab_size, device), { - "tokenizer_kind": "sentencepiece", - "source_model_name": str(tokenizer_path), - "vocab_size": int(sp.vocab_size()), - "meta_path": None, - "fallback": True, - } - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = 
torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# --- Quantization helpers --- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in 
os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - 
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )
    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue
        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        out_t = t.detach().to("cpu").contiguous()
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
        out[name] = out_t
    return out

# --- Data loading ---

_SHARD_NTOKENS_CACHE: dict[str, int] = {}
_MMAP_CACHE: dict[str, np.memmap] = {}

def load_data_shard(file: Path) -> Tensor:
    header_bytes = 256 * np.dtype("<i4").itemsize
    ...

def _read_num_tokens(file: Path) -> int:
    key = str(file)
    cached = _SHARD_NTOKENS_CACHE.get(key)
    if cached is not None:
        return cached
    header = np.fromfile(file, dtype="<i4", count=256)
    ...

def _get_shard_memmap(file: Path) -> np.memmap:
    key = str(file)
    mm = _MMAP_CACHE.get(key)
    if mm is not None:
        return mm
    n = _read_num_tokens(file)
    mm = np.memmap(file, mode="r", dtype="<u2", ...)
    _MMAP_CACHE[key] = mm
    return mm

...

    def _pick_coprime_stride(self, n: int) -> int:
        if n <= 1:
            return 1
        while True:
            s = int(self._rng.integers(1, n))
            if math.gcd(s, n) == 1:
                return s

    def _reset_cursor(self, si: int, seq_len: int) -> None:
        nt = int(self._num_tokens[si])
        max_phase = min(seq_len - 1, max(0, nt - seq_len - 1))
        phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0
        bc = (nt - 1 - phase) // seq_len
        self._cursor_phase[si] = phase
        self._cursor_block_count[si] = bc
        self._cursor_next[si] = 0
        self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0
        self._cursor_stride[si] = self._pick_coprime_stride(bc)
        self._cursor_init[si] = True

    def _ensure_cursor(self, si: int, seq_len: int) -> None:
        if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]:
            self._reset_cursor(si, seq_len)

    def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None:
        rem = count
        while rem > 0:
            self._ensure_cursor(si, seq_len)
            bc = int(self._cursor_block_count[si])
            ni = int(self._cursor_next[si])
            take = min(rem, bc - ni)
            phase = int(self._cursor_phase[si])
            start = int(self._cursor_start[si])
            stride = int(self._cursor_stride[si])
            for j in range(take):
                bi = (start + (ni + j) * stride) % bc
                out.append((si, phase + bi * seq_len))
            self._cursor_next[si] = ni + take
            rem -= take

    def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None:
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        num_seqs = local_tokens // seq_len
        global_num_seqs = num_seqs * self.world_size
        self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs)
        bbc = (self._num_tokens - 1) // seq_len
        eligible = bbc > 0
        self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64)
        self._base_block_counts = bbc[self._eligible_shards].astype(np.int64)

    def _sample_global_windows(self) -> list[tuple[int, int]]:
        assert self._cfg is not None and self._eligible_shards is not None
        _, seq_len, _, gns = self._cfg
        ec = int(self._eligible_shards.size)
        progress = min(self._batches_built / 1800.0, 1.0)
        remaining = np.empty(ec, dtype=np.float64)
        for i, si in enumerate(self._eligible_shards.tolist()):
            if self._cursor_init[si]:
                r = int(self._cursor_block_count[si]) - int(self._cursor_next[si])
                remaining[i] = float(max(r, 1))
            else:
                remaining[i] = float(self._base_block_counts[i])
        alpha = 0.90 - 0.40 * progress
        weights = np.power(remaining, alpha)
        ws = float(weights.sum())
        if not np.isfinite(ws) or ws <= 0.0:
            weights = np.ones(ec, dtype=np.float64)
            ws = float(weights.sum())
        probs = weights / ws
        low = min(max(8, self.world_size), ec, gns)
        high = min(max(32, self.world_size * 8), ec, gns)
        mix = max(1, min(int(round(low + progress * (high - low))), ec, gns))
        cp = self._rng.choice(ec, size=mix, replace=False, p=probs)
        cs = self._eligible_shards[cp]
        cpr = probs[cp].copy()
        cpr /= cpr.sum()
        counts = np.ones(mix, dtype=np.int64)
        extra = gns - mix
        if extra > 0:
            counts += self._rng.multinomial(extra, cpr).astype(np.int64)
        perm = self._rng.permutation(mix)
        cs, counts = cs[perm], counts[perm]
        buckets: list[list[tuple[int, int]]] = []
        for si, cnt in zip(cs.tolist(), counts.tolist()):
            b: list[tuple[int, int]] = []
            self._take_from_shard(int(si), seq_len, int(cnt), b)
            if b:
                if len(b) > 1:
                    bp = self._rng.permutation(len(b))
                    b = [b[int(k)] for k in bp.tolist()]
                buckets.append(b)
        windows: list[tuple[int, int]] = []
        active = [i for i, bk in enumerate(buckets) if bk]
        while active:
            order = self._rng.permutation(len(active))
            new_active: list[int] = []
            for oi in order.tolist():
                bi = active[oi]
                if buckets[bi]:
                    windows.append(buckets[bi].pop())
                    if buckets[bi]:
                        new_active.append(bi)
            active = new_active
        return windows

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        if self._cfg is None:
            self._init_pipeline(global_tokens, seq_len, grad_accum_steps)
        _, _, num_seqs, gns = self._cfg
        gw = self._sample_global_windows()
        local_w = gw[self.rank::self.world_size]
        x = torch.empty((num_seqs, seq_len), dtype=torch.int64)
        y = torch.empty((num_seqs, seq_len), dtype=torch.int64)
        for slot, (si, pos) in enumerate(local_w):
            mm = _get_shard_memmap(self.files[si])
            window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64))
            x[slot] = window[:-1]
            y[slot] = window[1:]
        self._batches_built += 1
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# --- Transformer modules ---

class RMSNorm(nn.Module):
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)

class CastedLinear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        w = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, bias)

def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()

class Rotary(nn.Module):
    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.train_seq_len = train_seq_len
        self.rope_dims = rope_dims if rope_dims > 0 else dim
        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            rd = self.rope_dims
            if seq_len > self.train_seq_len:
                scale = seq_len / self.train_seq_len
                new_base = self.base * (scale ** (rd / (rd - 2)))
                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
            else:
                inv_freq = self.inv_freq.to(device)
            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = freqs.cos()[None, :, None, :]
            self._sin_cached = freqs.sin()[None, :, None, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)

def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
    if rope_dims > 0 and rope_dims < x.size(-1):
        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
        half = rope_dims // 2
        x1, x2 = x_rope[..., :half], x_rope[..., half:]
        x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
        return torch.cat((x_rope, x_pass), dim=-1)
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)

class CausalSelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        # No CastedLinear -- weights come from banks
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rope_dims = 0  # set by GPT.__init__ for partial RoPE
        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024)
        self.use_xsa = False  # set by GPT.__init__ for deep layers only

    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
        """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave).

        y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv."""
        B, T, H, D = y.shape
        Hkv = v.size(-2)
        group = H // Hkv
        y_g = y.reshape(B, T, Hkv, group, D)  # [B, T, Hkv, group, D]
        vn = F.normalize(v, dim=-1).unsqueeze(-2)  # [B, T, Hkv, 1, D] -- broadcast ready
        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
        return (y_g - proj).reshape(B, T, H, D)

    def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None) -> Tensor:
        if getattr(self, '_save_gptq', False):
            self._gptq_qkv_in = x.detach()
        bsz, seqlen, dim = x.shape
        q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim)
        k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        v = F.linear(x, v_w.to(x.dtype))
        if v_embed is not None:
            v = v + v_embed
        v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
        y = flash_attn_3_func(q, k, v, causal=True)
        if self.use_xsa:
            y = self._xsa_efficient(y, v)
        y = y.reshape(bsz, seqlen, dim)
        if getattr(self, '_save_gptq', False):
            self._gptq_o_in = y.detach()
        return F.linear(y, out_w.to(x.dtype))

class SmearGate(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev

class BigramHashEmbedding(nn.Module):
    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)

class ValueEmbedding(nn.Module):
    """Reinject token identity into attention values at specific layers.
    Each table maps vocab tokens to a low-dim embedding, projected to model_dim."""

    def __init__(self, vocab_size: int, ve_dim: int, model_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, ve_dim)
        nn.init.normal_(self.embed.weight, std=0.01)
        self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(token_ids)
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)

class MLP(nn.Module):
    def __init__(self, dim: int, mlp_mult: int, neg_slope: float = 0.5):
        super().__init__()
        self.neg_slope = neg_slope
        # No CastedLinear -- weights come from banks

    def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor:
        if getattr(self, '_save_gptq', False):
            self._gptq_up_in = x.detach()
        x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=self.neg_slope)
        x2 = x.square()
        if getattr(self, '_save_gptq', False):
            self._gptq_down_in = x2.detach()
        return F.linear(x2, down_w.to(x.dtype))

class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        layer_idx: int = 0,
        ln_scale: bool = False,
        neg_slope: float = 0.5,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, neg_slope=neg_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0

    def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed)
        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
        mlp_out = self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w)
        return x_out + mlp_out

class GPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        xsa_last_n: int = 0,
        rope_dims: int = 0,
        ln_scale: bool = False,
        ve_enabled: bool = False,
        ve_dim: int = 128,
        ve_layers: str = "9,10",
        neg_slope: float = 0.5,
    ):
        super().__init__()
        self._ve_target_dim = num_kv_heads * (model_dim // num_heads)  # kv_dim for value projection
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.smear = SmearGate(model_dim)
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        # Parameter banks:
        # contiguous 3D tensors for batched optimizer
        head_dim = model_dim // num_heads
        kv_dim = num_kv_heads * head_dim
        mlp_dim = int(mlp_mult * model_dim)
        self.num_layers = num_layers
        self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim))
        self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim))
        self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim))
        self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim))
        self.blocks = nn.ModuleList(
            [
                Block(
                    model_dim,
                    num_heads,
                    num_kv_heads,
                    mlp_mult,
                    rope_base,
                    qk_gain_init,
                    layer_idx=i,
                    ln_scale=ln_scale,
                    neg_slope=neg_slope,
                )
                for i in range(num_layers)
            ]
        )
        if rope_dims > 0:
            head_dim = model_dim // num_heads
            for block in self.blocks:
                block.attn.rope_dims = rope_dims
                block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims)
        self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else []
        kv_dim_ve = self._ve_target_dim
        if self.ve_layer_indices:
            self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve)
            self.ve_layer_scales = nn.ParameterList(
                [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]
            )
        else:
            self.ve_shared = None
            self.ve_layer_scales = nn.ParameterList()
        self.value_embeds = nn.ModuleList()  # keep empty for compat
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        if xsa_last_n > 0:
            for i in range(max(0, num_layers - xsa_last_n), num_layers):
                self.blocks[i].attn.use_xsa = True
        self._init_weights()

    def _init_weights(self) -> None:
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        n = self.num_layers
        proj_scale = 1.0 / math.sqrt(2 * n)
        # Init banks: orthogonal,
        # with proj layers scaled down and out/down zero-init
        for i in range(n):
            nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0)  # Q
            nn.init.zeros_(self.qo_bank.data[n + i])  # Out (zero init)
            nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0)  # K
            nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0)  # V
            nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0)  # MLP up
            nn.init.zeros_(self.mlp_down_bank.data[i])  # MLP down (zero init)
            # Scale proj layers (out_proj and mlp_down are "proj" layers)
            self.qo_bank.data[n + i].mul_(proj_scale)
            self.mlp_down_bank.data[i].mul_(proj_scale)
        # Init remaining nn.Linear modules (bigram proj, lm_head)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)

    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None:
        """Get value embedding for a specific layer using shared table + per-layer scale."""
        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
            return None
        if ve_cache is not None and 've' not in ve_cache:
            ve_cache['ve'] = self.ve_shared(input_ids)
        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
        ve_idx = self.ve_layer_indices.index(layer_idx)
        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        n = self.num_layers
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        ve_cache: dict = {}
        for i in range(self.num_encoder_layers):
            ve = self._get_ve(i, input_ids, ve_cache)
            x = self.blocks[i](x, x0,
                               self.qo_bank[i], self.kv_bank[i],
                               self.kv_bank[n + i],
                               self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i],
                               v_embed=ve)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            bi = self.num_encoder_layers + i
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            ve = self._get_ve(bi, input_ids, ve_cache)
            x = self.blocks[bi](x, x0,
                                self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi],
                                self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi],
                                v_embed=ve)
        x = self.final_norm(x)
        x_flat = x.reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x_flat, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x_flat)
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Return logits (bsz, seq_len, vocab) without computing loss."""
        n = self.num_layers
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        ve_cache: dict = {}
        for i in range(self.num_encoder_layers):
            ve = self._get_ve(i, input_ids, ve_cache)
            x = self.blocks[i](x, x0,
                               self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i],
                               self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i],
                               v_embed=ve)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            bi = self.num_encoder_layers + i
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            ve = self._get_ve(bi, input_ids, ve_cache)
            x = self.blocks[bi](x, x0,
                                self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi],
                                self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi],
                                v_embed=ve)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

# --- Sliding window evaluation ---

def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
    eval_seq_len: int | None = None,
) -> tuple[float, float]:
    """Sliding window evaluation: each token scored with maximum context."""
    seq_len = eval_seq_len or args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= 1]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)
    base_model.eval()
    compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = compiled_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte


def eval_val_sliding_ttt(
    args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int,
    device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor,
    stride: int, batch_seqs: int = 32, log0=print,
) -> tuple[float, float]:
    """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows,
    then train on it.
    Every token scored BEFORE any update that could use it."""
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    ttt_chunk = args.ttt_chunk_tokens

    # Pre-compute all window starts
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]

    # Assign each window to a chunk based on the first token it scores
    num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk
    chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)]
    for ws in window_starts:
        end = min(ws + seq_len, total_tokens)
        wlen = end - ws
        s = 0 if ws == 0 else max(wlen - stride, 0)
        scored_start = ws + s
        ci = min(scored_start // ttt_chunk, num_chunks - 1)
        chunk_windows[ci].append(ws)

    log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} "
         f"total_windows={len(window_starts)} stride={stride} "
         f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} "
         f"freeze_blocks={args.ttt_freeze_blocks}")

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    # Freeze first N blocks
    frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks))))
    ttt_params = []
    for name, p in base_model.named_parameters():
        freeze = False
        for bi in frozen_block_ids:
            if f"blocks.{bi}." in name:
                freeze = True
                break
        if freeze:
            p.requires_grad_(False)
        else:
            p.requires_grad_(True)
            ttt_params.append(p)

    log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} "
         f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}")

    optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum)
    t0 = time.perf_counter()

    for ci in range(num_chunks):
        windows = chunk_windows[ci]
        if not windows:
            continue
        chunk_start = ci * ttt_chunk
        chunk_end = min((ci + 1) * ttt_chunk, total_tokens)

        # --- Phase 1: SCORE this chunk's windows (inference_mode) ---
        my_s = (len(windows) * rank) // world_size
        my_e = (len(windows) * (rank + 1)) // world_size
        my_windows = windows[my_s:my_e]

        base_model.eval()
        with torch.inference_mode():
            for bi in range(0, len(my_windows), batch_seqs):
                batch_ws = my_windows[bi:bi + batch_seqs]
                bsz = len(batch_ws)
                x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
                y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
                wlens: list[int] = []
                for i, ws in enumerate(batch_ws):
                    end = min(ws + seq_len, total_tokens)
                    wlen = end - ws
                    wlens.append(wlen)
                    chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                    x_batch[i, :wlen] = chunk_tok[:-1]
                    y_batch[i, :wlen] = chunk_tok[1:]
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    logits = base_model.forward_logits(x_batch)
                nll = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)).float(),
                    y_batch.reshape(-1), reduction="none",
                ).reshape(bsz, seq_len)
                for i, ws in enumerate(batch_ws):
                    wlen = wlens[i]
                    s = 0 if ws == 0 else max(wlen - stride, 0)
                    scored_nll = nll[i, s:wlen].to(torch.float64)
                    loss_sum += scored_nll.sum()
                    token_count += float(wlen - s)
                    tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen]
                    tb = base_bytes_lut[tgt].to(torch.float64)
                    tb += (has_leading_space_lut[tgt] &
                           ~is_boundary_token_lut[prev]).to(torch.float64)
                    byte_count += tb.sum()

        # --- Phase 2: TRAIN on this chunk (already scored = legal) ---
        is_last_chunk = (ci == num_chunks - 1)
        if not is_last_chunk and args.ttt_epochs > 0:
            base_model.train()
            chunk_seqs = (chunk_end - chunk_start) // seq_len
            if chunk_seqs > 0:
                cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1)))
                for pg in optimizer.param_groups:
                    pg['lr'] = cos_lr
                my_seq_s = (chunk_seqs * rank) // world_size
                my_seq_e = (chunk_seqs * (rank + 1)) // world_size
                my_chunk_seqs = my_seq_e - my_seq_s
                for _ep in range(args.ttt_epochs):
                    for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs):
                        be = min(bs + args.ttt_batch_seqs, my_chunk_seqs)
                        actual_bs = my_seq_s + bs
                        start_tok = chunk_start + actual_bs * seq_len
                        end_tok = chunk_start + (my_seq_s + be) * seq_len + 1
                        if end_tok > val_tokens.numel():
                            continue
                        local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64)
                        x = local[:-1].reshape(-1, seq_len)
                        y = local[1:].reshape(-1, seq_len)
                        optimizer.zero_grad(set_to_none=True)
                        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                            loss = base_model(x, y)
                        loss.backward()
                        if world_size > 1:
                            for p in ttt_params:
                                if p.grad is not None:
                                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
                        torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip)
                        optimizer.step()

        if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1):
            elapsed = time.perf_counter() - t0
            rl = loss_sum.item() / max(token_count.item(), 1)
            rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0
            log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s")

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item())

    for p in base_model.parameters():
        p.requires_grad_(True)
    base_model.eval()

    log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} "
         f"elapsed={time.perf_counter() - t0:.1f}s")
    return val_loss, val_bpb


# --- GPTQ-lite int6 quantization ---

def _classify_param(name: str) -> str:
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    t32 = t.float()
    if t32.ndim == 2:
        best_q, best_s, best_err = None, None, float('inf')
        for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
            if pct < 1.0:
                row_clip = torch.quantile(t32.abs(), pct, dim=1)
            else:
                row_clip = t32.abs().amax(dim=1)
            s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
            q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8)
            recon = q.float() * s.float()[:, None]
            err = (t32 - recon).pow(2).mean().item()
            if err < best_err:
                best_q, best_s, best_err = q, s, err
        return best_q, best_s
    amax = t32.abs().max().item()
    scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8)
    return q, scale

def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]:
    """Convert 3D bank tensors into individual 2D tensors with standard names."""
    out: dict[str, Tensor] = {}
    n = num_layers
    for name, tensor in sd.items():
        if name == "qo_bank":
            for i in range(n):
                out[f"blocks.{i}.attn.c_q.weight"] = tensor[i]
                out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i]
        elif name == "kv_bank":
            for i in range(n):
out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] - out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] - elif name == "mlp_up_bank": - for i in range(n): - out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] - elif name == "mlp_down_bank": - for i in range(n): - out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] - else: - out[name] = tensor - return out - -def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - """Convert individual 2D tensors back into 3D bank tensors.""" - out: dict[str, Tensor] = {} - n = num_layers - # Reconstruct banks from individual weight keys - qo_slices = [None] * (2 * n) - kv_slices = [None] * (2 * n) - up_slices = [None] * n - down_slices = [None] * n - consumed = set() - for i in range(n): - qk = f"blocks.{i}.attn.c_q.weight" - if qk in sd: - qo_slices[i] = sd[qk] - consumed.add(qk) - ok = f"blocks.{i}.attn.proj.weight" - if ok in sd: - qo_slices[n + i] = sd[ok] - consumed.add(ok) - kk = f"blocks.{i}.attn.c_k.weight" - if kk in sd: - kv_slices[i] = sd[kk] - consumed.add(kk) - vk = f"blocks.{i}.attn.c_v.weight" - if vk in sd: - kv_slices[n + i] = sd[vk] - consumed.add(vk) - fk = f"blocks.{i}.mlp.fc.weight" - if fk in sd: - up_slices[i] = sd[fk] - consumed.add(fk) - dk = f"blocks.{i}.mlp.proj.weight" - if dk in sd: - down_slices[i] = sd[dk] - consumed.add(dk) - out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) - out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) - out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) - out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) - for name, tensor in sd.items(): - if name not in consumed: - out[name] = tensor - return out - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], clip_range: int = 31, - hessians: dict[str, Tensor] | None = None): - num_layers_total = max( - (int(k.split(".")[1]) for 
k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - gptq_count, naive_count = 0, 0 - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - H = hessians.get(name) if hessians else None - if H is not None and t.ndim == 2: - q, s = gptq_quantize_weight(t, H.cpu(), clip_range=clip_range) - gptq_count += 1 - else: - q, s = quantize_int6_per_row(t, clip_range=clip_range) - naive_count += 1 - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - if hessians: - print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers", flush=True) - return result, meta -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * 
float(s.item())).to(orig_dtype) - return out - -# --- Full Hessian GPTQ --- - -def gptq_quantize_weight(W: Tensor, H: Tensor, clip_range: int = 31, - block_size: int = 128, percdamp: float = 0.01) -> tuple[Tensor, Tensor]: - """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023).""" - W_orig = W.float().clone() - rows, cols = W_orig.shape - H = H.float().clone() - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - damp = percdamp * H.diag().mean() - H.diagonal().add_(damp) - perm = torch.argsort(H.diag(), descending=True) - invperm = torch.argsort(perm) - W_perm = W_orig[:, perm].clone() - W_perm[:, dead[perm]] = 0 - H = H[perm][:, perm] - try: - Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) - Hinv = torch.linalg.cholesky(Hinv, upper=True) - except torch.linalg.LinAlgError: - return quantize_int6_per_row(W_orig, clip_range) - best_q, best_scale, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(W_orig.abs(), pct, dim=1) - else: - row_clip = W_orig.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - sf = s.float() - Q = torch.zeros(rows, cols, dtype=torch.int8) - W_work = W_perm.clone() - for i1 in range(0, cols, block_size): - i2 = min(i1 + block_size, cols) - W_block = W_work[:, i1:i2].clone() - Hinv_block = Hinv[i1:i2, i1:i2] - Err = torch.zeros(rows, i2 - i1) - for j in range(i2 - i1): - w_col = W_block[:, j] - d = Hinv_block[j, j] - q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) - Q[:, i1 + j] = q_col.to(torch.int8) - err = (w_col - q_col.float() * sf) / d - Err[:, j] = err - W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) - if i2 < cols: - W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] - recon = Q.float() * sf[:, None] - mse = (W_perm - recon).pow(2).mean().item() - if mse < best_err: - best_q, best_scale, best_err = Q, s, mse - best_q = best_q[:, invperm] - return 
best_q, best_scale - -def _init_hessians(nl: int, dim: int, mlp_dim: int, device: torch.device) -> dict[str, Tensor]: - h: dict[str, Tensor] = {} - for i in range(nl): - for k in ['c_q', 'c_k', 'c_v']: - h[f'blocks.{i}.attn.{k}.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.attn.proj.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.mlp.fc.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.mlp.proj.weight'] = torch.zeros(mlp_dim, mlp_dim, dtype=torch.float32, device=device) - return h - -def _accum_hessians(hessians: dict[str, Tensor], blocks: nn.ModuleList, dim: int, mlp_dim: int) -> None: - for i, block in enumerate(blocks): - qkv_in = block.attn._gptq_qkv_in.float().reshape(-1, dim) - h_qkv = qkv_in.t() @ qkv_in - hessians[f'blocks.{i}.attn.c_q.weight'] += h_qkv - hessians[f'blocks.{i}.attn.c_k.weight'] += h_qkv - hessians[f'blocks.{i}.attn.c_v.weight'] += h_qkv - o_in = block.attn._gptq_o_in.float().reshape(-1, dim) - hessians[f'blocks.{i}.attn.proj.weight'] += o_in.t() @ o_in - up_in = block.mlp._gptq_up_in.float().reshape(-1, dim) - hessians[f'blocks.{i}.mlp.fc.weight'] += up_in.t() @ up_in - down_in = block.mlp._gptq_down_in.float().reshape(-1, mlp_dim) - hessians[f'blocks.{i}.mlp.proj.weight'] += down_in.t() @ down_in - -def _finalize_hessians(hessians: dict[str, Tensor], num_batches: int) -> None: - for name in hessians: - hessians[name] = hessians[name].cpu() / num_batches - damp = 0.01 * torch.diag(hessians[name]).mean().clamp_min(1e-6) - hessians[name] += damp * torch.eye(hessians[name].shape[0]) - -def gptq_collect_hessians(base_model: nn.Module, train_loader, device: torch.device, - num_batches: int, batch_tokens: int, seq_len: int, - grad_accum_steps: int) -> dict[str, Tensor]: - """Collect Hessians H = X^T X from training data.""" - nl = base_model.num_layers - dim = base_model.tok_emb.weight.shape[1] - mlp_dim = 
base_model.mlp_up_bank.shape[1] - hessians = _init_hessians(nl, dim, mlp_dim, device) - for block in base_model.blocks: - block.attn._save_gptq = True - block.mlp._save_gptq = True - base_model.eval() - with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): - for _ in range(num_batches): - x, y = train_loader.next_batch(batch_tokens, seq_len, grad_accum_steps) - base_model(x, y) - _accum_hessians(hessians, base_model.blocks, dim, mlp_dim) - for block in base_model.blocks: - block.attn._save_gptq = False - block.mlp._save_gptq = False - _finalize_hessians(hessians, num_batches) - base_model.train() - return hessians - -# --- Training --- - -def main() -> None: - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - logfile = None - if master_process: - 
os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - (base_bytes_lut, has_leading_space_lut, is_boundary_token_lut), tokenizer_metadata = load_tokenizer_luts( - args.tokenizer_path, args.tokenizer_meta_path, args.vocab_size, device, - validate_meta=args.tokenizer_meta_validate, - ) - log0(f"tokenizer: kind={tokenizer_metadata.get('tokenizer_kind', 'unknown')} vocab={tokenizer_metadata.get('vocab_size', '?')}") - if tokenizer_metadata.get('tokenizer_kind') == 'sentencepiece': - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - base_model = GPT( - 
vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ve_enabled=args.ve_enabled, - ve_dim=args.ve_dim, - ve_layers=args.ve_layers, - neg_slope=args.negative_slope, - ).to(device).bfloat16() - # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward - base_model.qo_bank.data = base_model.qo_bank.data.float() - base_model.kv_bank.data = base_model.kv_bank.data.float() - base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() - base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, - # and non-bank grads are manually all-reduced before Adam steps. 
- compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model = compiled_model - - # Optimizer split: - # - 4 parameter banks -> Muon (batched Newton-Schulz) - # - token embedding -> Adam - # - scalars/control tensors -> Adam - # - bigram proj, VE proj -> Adam (small matrix params not worth banking) - matrix_params = [ - base_model.qo_bank, base_model.kv_bank, - base_model.mlp_up_bank, base_model.mlp_down_bank, - ] - block_named_params = list(base_model.blocks.named_parameters()) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - scalar_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - scalar_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in 
optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - # Non-bank params that need manual all-reduce (replicated across GPUs) - replicated_params = list(optimizer_tok.param_groups[0]["params"]) - for pg in optimizer_tok.param_groups[1:]: - replicated_params.extend(pg["params"]) - replicated_params.extend(scalar_params) - - optimizer_head = None - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - replicated_params.append(base_model.lm_head.weight) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if optimizer_head is not None: - optimizers.append(optimizer_head) - log0(f"model_params:{sum(p.numel() for p in base_model.parameters())}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - if args.use_gptq and max_wallclock_ms is not None: - max_wallclock_ms -= args.gptq_reserve_ms - log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - # All-reduce all grads for warmup (simple, not optimized) - if distributed: - for p in base_model.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 
1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = 0.997 - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * 
grad_scale).backward() - train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - # === 3-phase overlapped optimizer step === - # Phase 1: Launch async reduce-scatter for banks (biggest first) - optimizer_muon.launch_reduce_scatters() - # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) - if distributed: - for p in replicated_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - optimizer_tok.step() - optimizer_scalar.step() - if optimizer_head is not None: - optimizer_head.step() - # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) - optimizer_muon.step() - zero_grad_all() - # EMA update - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - 
f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - # Apply EMA weights - log0("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - torch.cuda.synchronize() - t_diag = time.perf_counter() - diag_val_loss, diag_val_bpb = eval_val( - args, compiled_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" - ) - export_sd = base_model.state_dict() - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - # Unbank 3D tensors into individual 2D tensors for quantization - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) - # GPTQ calibration: collect Hessians from training data - gptq_hessians = None - if args.use_gptq: - t_gptq = time.perf_counter() - log0(f"gptq:calibrating with {args.gptq_calib_samples} 
batches (training data)...") - calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - gptq_hessians = gptq_collect_hessians( - base_model, calib_loader, device, num_batches=args.gptq_calib_samples, - batch_tokens=args.train_batch_tokens, seq_len=args.train_seq_len, - grad_accum_steps=grad_accum_steps) - del calib_loader - gptq_elapsed = time.perf_counter() - t_gptq - log0(f"gptq:calibrated {len(gptq_hessians)} layers in {gptq_elapsed:.1f}s") - torch.cuda.empty_cache() - quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}, clip_range=args.quant_clip_range, hessians=gptq_hessians) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = lzma.compress(quant_raw, preset=6) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") - log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(lzma.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) - # Re-bank the dequantized tensors - deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - 
xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, ln_scale=args.ln_scale, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - neg_slope=args.negative_slope, - ).to(device).bfloat16() - eval_model.qo_bank.data = eval_model.qo_bank.data.float() - eval_model.kv_bank.data = eval_model.kv_bank.data.float() - eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() - eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - 
log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - # Legal score-first TTT (PR #461 recipe) - if args.ttt_enabled: - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_loss, ttt_bpb = eval_val_sliding_ttt( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, log0=log0, - ) - torch.cuda.synchronize() - log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") - log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Mar 3 2026, 12:15:18) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Tue Mar 31 15:24:47 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A 
| Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 38C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 32C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3F:00.0 Off | 0 | -| N/A 31C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:48:00.0 Off | 0 | -| N/A 40C P0 127W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | -| N/A 37C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | -| N/A 30C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | -| N/A 30C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | -| N/A 38C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | 
-+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 63786 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 1 N/A N/A 63787 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 2 N/A N/A 63788 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 3 N/A N/A 63789 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 4 N/A N/A 63790 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 5 N/A N/A 63791 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 6 N/A N/A 63792 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 7 N/A N/A 63793 C ...ameter-golf/.venv/bin/python3 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -tokenizer: kind=tokenmonster vocab=998 -train_loader:dataset:fineweb10B_scylla train_shards:194 -val_loader:shards pattern=./data/datasets/fineweb10B_scylla/fineweb_val_*.bin tokens:62363648 -model_params:27022172 -XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:2025 -gptq:reserving 9000ms from training budget, effective=591000ms -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 
-warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.8931 val_bpb:3.3925 train_time:0ms step_avg:0.01ms -step:1/9000 train_loss:6.8925 train_time:122ms step_avg:121.74ms -step:2/9000 train_loss:9.2176 train_time:155ms step_avg:77.61ms -step:3/9000 train_loss:8.9342 train_time:242ms step_avg:80.53ms -step:4/9000 train_loss:8.4847 train_time:327ms step_avg:81.64ms -step:5/9000 train_loss:8.0678 train_time:411ms step_avg:82.16ms -step:6/9000 train_loss:7.5826 train_time:496ms step_avg:82.60ms -step:7/9000 train_loss:7.0192 train_time:581ms step_avg:83.07ms -step:8/9000 train_loss:6.7110 train_time:667ms step_avg:83.40ms -step:9/9000 train_loss:6.6151 train_time:752ms step_avg:83.58ms -step:10/9000 train_loss:6.4055 train_time:837ms step_avg:83.68ms -step:500/9000 train_loss:2.3213 train_time:43601ms step_avg:87.20ms -step:1000/9000 train_loss:2.2074 train_time:87405ms step_avg:87.41ms -step:1500/9000 train_loss:2.1863 train_time:131253ms step_avg:87.50ms -step:2000/9000 train_loss:2.1492 train_time:175187ms step_avg:87.59ms -step:2500/9000 train_loss:2.2069 train_time:219134ms step_avg:87.65ms -step:3000/9000 train_loss:2.1050 train_time:263091ms step_avg:87.70ms -step:3500/9000 train_loss:2.0852 train_time:307079ms step_avg:87.74ms -step:4000/9000 train_loss:2.0364 train_time:351032ms step_avg:87.76ms -step:4000/9000 val_loss:2.0777 val_bpb:1.0226 train_time:351087ms step_avg:87.77ms -step:4500/9000 train_loss:2.0086 train_time:395014ms step_avg:87.78ms -step:5000/9000 train_loss:1.9688 train_time:438965ms step_avg:87.79ms -step:5500/9000 train_loss:1.9353 train_time:482910ms step_avg:87.80ms -step:6000/9000 train_loss:1.9898 train_time:526848ms step_avg:87.81ms -swa:start step:6050 -step:6500/9000 train_loss:1.9760 train_time:571439ms step_avg:87.91ms -step:6719/9000 val_loss:1.9633 val_bpb:0.9663 train_time:591097ms step_avg:87.97ms -stopping_early: wallclock_cap 
train_time:591097ms step:6719/9000 -peak memory allocated: 23031 MiB reserved: 23324 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9617 val_bpb:0.9655 eval_time:2111ms -Serialized model: 106201974 bytes -Code size: 101929 bytes -gptq:calibrating with 64 batches (training data)... -gptq:calibrated 66 layers in 6.7s -Serialized model int6+lzma: 15874004 bytes -Total submission size int6+lzma: 15975933 bytes -final_int6_roundtrip val_loss:1.9677 val_bpb:0.9684 eval_time:5337ms -final_int6_roundtrip_exact val_loss:1.96774016 val_bpb:0.96844526 -final_int6_sliding_window val_loss:1.9280 val_bpb:0.9489 stride:64 eval_time:76754ms -final_int6_sliding_window_exact val_loss:1.92801508 val_bpb:0.94888552 -final_int8_zlib_roundtrip_exact val_loss:1.92801508 val_bpb:0.94888552 diff --git a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed42.log b/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed42.log deleted file mode 100644 index a3609416eb..0000000000 --- a/records/track_10min_16mb/2026-03-31_Scylla_FullGPTQ_XSA11_FA3_0.9485/train_seed42.log +++ /dev/null @@ -1,2308 +0,0 @@ -from __future__ import annotations -import copy -import glob -import io -import lzma -import math -import os -import random -import subprocess -import sys -import time -import uuid -from pathlib import Path -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func as flash_attn_3_func -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = 
os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 
0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - muon_wd = float(os.environ.get("MUON_WD", 0.04)) - adam_wd = float(os.environ.get("ADAM_WD", 0.04)) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - rope_dims = int(os.environ.get("ROPE_DIMS", 16)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) - ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) - ve_dim = int(os.environ.get("VE_DIM", 128)) - ve_layers = os.environ.get("VE_LAYERS", "9,10") - ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) - ttt_lr = float(os.environ.get("TTT_LR", 0.002)) - ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) - ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) - ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) - ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) - ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) - ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) - negative_slope = float(os.environ.get("NEGATIVE_SLOPE", 0.5)) - use_gptq = bool(int(os.environ.get("USE_GPTQ", "0"))) - gptq_calib_samples = int(os.environ.get("GPTQ_CALIB_SAMPLES", "64")) - gptq_reserve_ms = float(os.environ.get("GPTQ_RESERVE_MS", "14000")) - quant_clip_range = int(os.environ.get("QUANT_CLIP_RANGE", 31)) - tokenizer_meta_path = os.environ.get("TOKENIZER_META_PATH", "") - tokenizer_meta_validate = bool(int(os.environ.get("TOKENIZER_META_VALIDATE", "0"))) - -# --- Batched Newton-Schulz orthogonalization --- - -def zeropower_via_newtonschulz5(G: Tensor, steps: 
int = 5, eps: float = 1e-7) -> Tensor: - """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" - a, b, c = (3.4445, -4.7750, 2.0315) - was_2d = G.ndim == 2 - if was_2d: - G = G.unsqueeze(0) - X = G.bfloat16() - transposed = X.size(-2) > X.size(-1) - if transposed: - X = X.mT - X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) - for _ in range(steps): - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - if transposed: - X = X.mT - if was_2d: - X = X.squeeze(0) - return X - -# --- Parallel Muon optimizer --- - -class Muon(torch.optim.Optimizer): - """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. - - No DDP for bank params. After backward, this optimizer: - 1. Launches async reduce-scatter for all banks (biggest first) - 2. Returns control so Adam can step on small params while RS is in-flight - 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather - 4. Each all-gather overlaps with next bank's NS5 - """ - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - self._built = False - - def _build(self): - self._distributed = dist.is_available() and dist.is_initialized() - self._world_size = dist.get_world_size() if self._distributed else 1 - self._rank = dist.get_rank() if self._distributed else 0 - ws = self._world_size - - self._bank_meta = [] - for group in self.param_groups: - for p in group["params"]: - B = p.shape[0] - padded_B = ((B + ws - 1) // ws) * ws - shard_B = padded_B // ws - tail = p.shape[1:] - dev = p.device - self._bank_meta.append({ - 'p': p, - 'B': B, - 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard_mom': torch.zeros(shard_B, *tail, device=dev, 
dtype=torch.bfloat16), - 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, - }) - # Sort by size descending -- launch biggest reduce-scatters first - self._bank_meta.sort(key=lambda m: -m['p'].numel()) - self._built = True - - def launch_reduce_scatters(self): - """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" - if not self._built: - self._build() - if not self._distributed: - return - self._rs_futures = [] - for m in self._bank_meta: - p = m['p'] - if p.grad is None: - self._rs_futures.append(None) - continue - pg = m['padded_grad'] - pg[:m['B']].copy_(p.grad.bfloat16()) - if pg.shape[0] > m['B']: - pg[m['B']:].zero_() - fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) - self._rs_futures.append(fut) - - @torch.no_grad() - def step(self, closure=None): - """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - if not self._built: - self._build() - - for group in self.param_groups: - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group.get("weight_decay", 0.0) - - prev_ag_handle = None - prev_m = None - - sharded = self._distributed and hasattr(self, '_rs_futures') - - for i, m in enumerate(self._bank_meta): - p = m['p'] - if p.grad is None: - continue - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if sharded and self._rs_futures[i] is not None: - self._rs_futures[i].wait() - g = m['shard'] - buf = m['shard_mom'] - else: - g = p.grad.bfloat16() - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = 
torch.zeros_like(g) - buf = state["momentum_buffer"] - - buf.mul_(momentum).add_(g) - if nesterov: - update = g.add(buf, alpha=momentum) - else: - update = buf - - update = zeropower_via_newtonschulz5(update, steps=backend_steps) - - if sharded: - prev_ag_handle = dist.all_gather_into_tensor( - m['full_update'], update, async_op=True) - prev_m = m - else: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if hasattr(self, '_rs_futures'): - del self._rs_futures - - return loss - -# --- Tokenizer evaluation helpers --- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -TOKENIZER_META_FORMAT_VERSION = 1 -TOKENIZER_META_SUFFIX = ".meta.npz" - - -def _derive_tokenizer_meta_path(tokenizer_path: 
str) -> Path: - tokenizer = Path(tokenizer_path) - if tokenizer.suffix == ".model": - return tokenizer.with_suffix(TOKENIZER_META_SUFFIX) - return tokenizer.with_name(tokenizer.name + TOKENIZER_META_SUFFIX) - - -def build_sentencepiece_luts_np( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return base_bytes_np, has_leading_space_np, is_boundary_token_np - - -def load_tokenizer_meta_luts_np( - meta_path: Path, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, object]]: - def _scalar(value): - arr = np.asarray(value) - if arr.ndim == 0: - return arr.item() - first = arr.reshape(-1)[0] - return first.item() if hasattr(first, "item") else first - - with np.load(meta_path, allow_pickle=False) as data: - format_version = int(_scalar(data["format_version"])) - if format_version != TOKENIZER_META_FORMAT_VERSION: - raise ValueError( - f"Unsupported tokenizer meta format_version={format_version} " - f"expected={TOKENIZER_META_FORMAT_VERSION}" - ) - meta_vocab_size = int(_scalar(data["vocab_size"])) - tokenizer_kind = str(_scalar(data["tokenizer_kind"])) - source_model_name = str(_scalar(data["source_model_name"])) - base_bytes_np = np.asarray(data["base_bytes"], dtype=np.int16) - has_leading_space_np = 
np.asarray(data["has_leading_space"], dtype=np.bool_) - is_boundary_token_np = np.asarray(data["is_boundary_token"], dtype=np.bool_) - table_size = max(meta_vocab_size, vocab_size) - if base_bytes_np.shape[0] < table_size: - padded_base_bytes = np.zeros((table_size,), dtype=np.int16) - padded_has_leading_space = np.zeros((table_size,), dtype=np.bool_) - padded_is_boundary = np.ones((table_size,), dtype=np.bool_) - padded_base_bytes[: base_bytes_np.shape[0]] = base_bytes_np - padded_has_leading_space[: has_leading_space_np.shape[0]] = has_leading_space_np - padded_is_boundary[: is_boundary_token_np.shape[0]] = is_boundary_token_np - base_bytes_np = padded_base_bytes - has_leading_space_np = padded_has_leading_space - is_boundary_token_np = padded_is_boundary - metadata = { - "format_version": format_version, - "tokenizer_kind": tokenizer_kind, - "source_model_name": source_model_name, - "vocab_size": meta_vocab_size, - "meta_path": str(meta_path), - } - return base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata - - -def load_tokenizer_luts( - tokenizer_path: str, - tokenizer_meta_path: str, - vocab_size: int, - device: torch.device, - *, - validate_meta: bool = False, -) -> tuple[tuple[Tensor, Tensor, Tensor], dict[str, object]]: - meta_path = ( - Path(tokenizer_meta_path) if tokenizer_meta_path - else _derive_tokenizer_meta_path(tokenizer_path) - ) - if meta_path.exists(): - base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata = ( - load_tokenizer_meta_luts_np(meta_path, vocab_size) - ) - if validate_meta and str(tokenizer_path).endswith(".model"): - sp = spm.SentencePieceProcessor(model_file=tokenizer_path) - sp_luts = build_sentencepiece_luts_np(sp, vocab_size) - if not ( - np.array_equal(base_bytes_np, sp_luts[0]) - and np.array_equal(has_leading_space_np, sp_luts[1]) - and np.array_equal(is_boundary_token_np, sp_luts[2]) - ): - raise ValueError(f"Tokenizer metadata mismatch for {meta_path}") - return ( - 
torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ), metadata - if not str(tokenizer_path).endswith(".model"): - raise FileNotFoundError(f"TOKENIZER_META_PATH does not exist: {meta_path}") - sp = spm.SentencePieceProcessor(model_file=tokenizer_path) - return build_sentencepiece_luts(sp, vocab_size, device), { - "tokenizer_kind": "sentencepiece", - "source_model_name": str(tokenizer_path), - "vocab_size": int(sp.vocab_size()), - "meta_path": None, - "fallback": True, - } - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // 
world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# --- Quantization helpers --- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", - 
).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - 
passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = 
(q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - -# --- Data loading --- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" int: - key = str(file) - cached = _SHARD_NTOKENS_CACHE.get(key) - if cached is not None: - return cached - header = np.fromfile(file, dtype=" np.memmap: - key = str(file) - mm = _MMAP_CACHE.get(key) - if mm is not None: - return mm - n = _read_num_tokens(file) - mm = np.memmap(file, mode="r", dtype=" int: - if n <= 1: - return 1 - while True: - s = int(self._rng.integers(1, n)) - if math.gcd(s, n) == 1: - return s - def _reset_cursor(self, si: int, seq_len: int) -> None: - nt = int(self._num_tokens[si]) - max_phase = min(seq_len - 1, max(0, nt - seq_len - 1)) - phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0 - bc = (nt - 1 - phase) // seq_len - self._cursor_phase[si] = phase - self._cursor_block_count[si] = bc - self._cursor_next[si] = 0 - self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0 - self._cursor_stride[si] = self._pick_coprime_stride(bc) - self._cursor_init[si] = True - def _ensure_cursor(self, si: int, seq_len: int) -> None: - if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]: - self._reset_cursor(si, seq_len) - def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None: - rem = count - while rem > 0: - self._ensure_cursor(si, seq_len) - bc = int(self._cursor_block_count[si]) - ni = int(self._cursor_next[si]) - take = min(rem, bc - ni) - phase = int(self._cursor_phase[si]) - start = int(self._cursor_start[si]) - stride = int(self._cursor_stride[si]) - for j in range(take): - bi = (start + (ni + j) * stride) % bc - 
out.append((si, phase + bi * seq_len)) - self._cursor_next[si] = ni + take - rem -= take - def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - num_seqs = local_tokens // seq_len - global_num_seqs = num_seqs * self.world_size - self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) - bbc = (self._num_tokens - 1) // seq_len - eligible = bbc > 0 - self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) - self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) - def _sample_global_windows(self) -> list[tuple[int, int]]: - assert self._cfg is not None and self._eligible_shards is not None - _, seq_len, _, gns = self._cfg - ec = int(self._eligible_shards.size) - progress = min(self._batches_built / 1800.0, 1.0) - remaining = np.empty(ec, dtype=np.float64) - for i, si in enumerate(self._eligible_shards.tolist()): - if self._cursor_init[si]: - r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) - remaining[i] = float(max(r, 1)) - else: - remaining[i] = float(self._base_block_counts[i]) - alpha = 0.90 - 0.40 * progress - weights = np.power(remaining, alpha) - ws = float(weights.sum()) - if not np.isfinite(ws) or ws <= 0.0: - weights = np.ones(ec, dtype=np.float64) - ws = float(weights.sum()) - probs = weights / ws - low = min(max(8, self.world_size), ec, gns) - high = min(max(32, self.world_size * 8), ec, gns) - mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) - cp = self._rng.choice(ec, size=mix, replace=False, p=probs) - cs = self._eligible_shards[cp] - cpr = probs[cp].copy() - cpr /= cpr.sum() - counts = np.ones(mix, dtype=np.int64) - extra = gns - mix - if extra > 0: - counts += self._rng.multinomial(extra, cpr).astype(np.int64) - perm = self._rng.permutation(mix) - cs, counts = cs[perm], counts[perm] - buckets: list[list[tuple[int, int]]] = [] - for si, cnt in zip(cs.tolist(), 
counts.tolist()): - b: list[tuple[int, int]] = [] - self._take_from_shard(int(si), seq_len, int(cnt), b) - if b: - if len(b) > 1: - bp = self._rng.permutation(len(b)) - b = [b[int(k)] for k in bp.tolist()] - buckets.append(b) - windows: list[tuple[int, int]] = [] - active = [i for i, bk in enumerate(buckets) if bk] - while active: - order = self._rng.permutation(len(active)) - new_active: list[int] = [] - for oi in order.tolist(): - bi = active[oi] - if buckets[bi]: - windows.append(buckets[bi].pop()) - if buckets[bi]: - new_active.append(bi) - active = new_active - return windows - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - if self._cfg is None: - self._init_pipeline(global_tokens, seq_len, grad_accum_steps) - _, _, num_seqs, gns = self._cfg - gw = self._sample_global_windows() - local_w = gw[self.rank::self.world_size] - x = torch.empty((num_seqs, seq_len), dtype=torch.int64) - y = torch.empty((num_seqs, seq_len), dtype=torch.int64) - for slot, (si, pos) in enumerate(local_w): - mm = _get_shard_memmap(self.files[si]) - window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) - x[slot] = window[:-1] - y[slot] = window[1:] - self._batches_built += 1 - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# --- Transformer modules --- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - self.rope_dims = rope_dims if rope_dims > 0 else dim - inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: - if rope_dims > 0 and rope_dims < x.size(-1): - x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] - half = rope_dims // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rope, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, 
x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - # No CastedLinear -- weights come from banks - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = 0 # set by GPT.__init__ for partial RoPE - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None) -> Tensor: - if getattr(self, '_save_gptq', False): - self._gptq_qkv_in = x.detach() - bsz, seqlen, dim = x.shape - q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = F.linear(x, v_w.to(x.dtype)) - if v_embed is not None: - v = v + v_embed - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin, self.rope_dims) - k = apply_rotary_emb(k, cos, sin, self.rope_dims) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - if getattr(self, '_save_gptq', False): - self._gptq_o_in = y.detach() - return F.linear(y, out_w.to(x.dtype)) - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = 
nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class ValueEmbedding(nn.Module): - """Reinject token identity into attention values at specific layers. - Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" - def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): - super().__init__() - self.embed = nn.Embedding(vocab_size, ve_dim) - nn.init.normal_(self.embed.weight, std=0.01) - self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(token_ids) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int, neg_slope: float = 0.5): - super().__init__() - self.neg_slope = neg_slope - # No CastedLinear -- weights come from banks - def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: - if getattr(self, '_save_gptq', False): - self._gptq_up_in = x.detach() - x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=self.neg_slope) - x2 = x.square() - if getattr(self, '_save_gptq', 
False): - self._gptq_down_in = x2.detach() - return F.linear(x2, down_w.to(x.dtype)) - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - layer_idx: int = 0, - ln_scale: bool = False, - neg_slope: float = 0.5, - ): - super().__init__() - self.layer_idx = layer_idx - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult, neg_slope=neg_slope) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed) - x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out - mlp_out = self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) - return x_out + mlp_out - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - ve_enabled: bool = False, - ve_dim: int = 128, - ve_layers: str = "9,10", - neg_slope: 
float = 0.5, - ): - super().__init__() - self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - # Parameter banks: contiguous 3D tensors for batched optimizer - head_dim = model_dim // num_heads - kv_dim = num_kv_heads * head_dim - mlp_dim = int(mlp_mult * model_dim) - self.num_layers = num_layers - self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) - self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) - self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) - self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - layer_idx=i, - ln_scale=ln_scale, - neg_slope=neg_slope, - ) - for i in range(num_layers) - ] - ) - if rope_dims > 0: - head_dim = model_dim // num_heads - for block in self.blocks: - block.attn.rope_dims = rope_dims - block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] - kv_dim_ve = self._ve_target_dim - if self.ve_layer_indices: - self.ve_shared = 
ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) - self.ve_layer_scales = nn.ParameterList( - [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] - ) - else: - self.ve_shared = None - self.ve_layer_scales = nn.ParameterList() - self.value_embeds = nn.ModuleList() # keep empty for compat - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - n = self.num_layers - proj_scale = 1.0 / math.sqrt(2 * n) - # Init banks: orthogonal, with proj layers scaled down and out/down zero-init - for i in range(n): - nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q - nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) - nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K - nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V - nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up - nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) - # Scale proj layers (out_proj and mlp_down are "proj" layers) - self.qo_bank.data[n + i].mul_(proj_scale) - self.mlp_down_bank.data[i].mul_(proj_scale) - # Init remaining nn.Linear modules (bigram proj, lm_head) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: - """Get value embedding for a specific layer using shared 
table + per-layer scale.""" - if self.ve_shared is None or layer_idx not in self.ve_layer_indices: - return None - if ve_cache is not None and 've' not in ve_cache: - ve_cache['ve'] = self.ve_shared(input_ids) - ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) - ve_idx = self.ve_layer_indices.index(layer_idx) - return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve) - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - n = self.num_layers - x = 
self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - -# --- Sliding window evaluation --- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = 
torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -def eval_val_sliding_ttt( - args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, - 
device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - stride: int, batch_seqs: int = 32, log0=print, -) -> tuple[float, float]: - """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, - then train on it. Every token scored BEFORE any update that could use it.""" - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - ttt_chunk = args.ttt_chunk_tokens - - # Pre-compute all window starts - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - - # Assign each window to a chunk based on the first token it scores - num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk - chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] - for ws in window_starts: - end = min(ws + seq_len, total_tokens) - wlen = end - ws - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_start = ws + s - ci = min(scored_start // ttt_chunk, num_chunks - 1) - chunk_windows[ci].append(ws) - - log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " - f"total_windows={len(window_starts)} stride={stride} " - f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " - f"freeze_blocks={args.ttt_freeze_blocks}") - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Freeze first N blocks - frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) - ttt_params = [] - for name, p in base_model.named_parameters(): - freeze = False - for bi in frozen_block_ids: - if f"blocks.{bi}." 
in name: - freeze = True - break - if freeze: - p.requires_grad_(False) - else: - p.requires_grad_(True) - ttt_params.append(p) - - log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " - f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") - - optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) - t0 = time.perf_counter() - - for ci in range(num_chunks): - windows = chunk_windows[ci] - if not windows: - continue - chunk_start = ci * ttt_chunk - chunk_end = min((ci + 1) * ttt_chunk, total_tokens) - - # --- Phase 1: SCORE this chunk's windows (inference_mode) --- - my_s = (len(windows) * rank) // world_size - my_e = (len(windows) * (rank + 1)) // world_size - my_windows = windows[my_s:my_e] - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk_tok[:-1] - y_batch[i, :wlen] = chunk_tok[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - # --- Phase 2: TRAIN on this chunk (already scored = legal) --- - is_last_chunk = (ci == num_chunks - 1) - if not is_last_chunk and args.ttt_epochs > 0: - base_model.train() - chunk_seqs = (chunk_end - chunk_start) // seq_len - if chunk_seqs > 0: - cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) - for pg in optimizer.param_groups: - pg['lr'] = cos_lr - my_seq_s = (chunk_seqs * rank) // world_size - my_seq_e = (chunk_seqs * (rank + 1)) // world_size - my_chunk_seqs = my_seq_e - my_seq_s - for _ep in range(args.ttt_epochs): - for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): - be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) - actual_bs = my_seq_s + bs - start_tok = chunk_start + actual_bs * seq_len - end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 - if end_tok > val_tokens.numel(): - continue - local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - optimizer.zero_grad(set_to_none=True) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = base_model(x, y) - loss.backward() - if world_size > 1: - for p in ttt_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) - optimizer.step() - - if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): - elapsed = time.perf_counter() - t0 - rl = loss_sum.item() / max(token_count.item(), 1) - rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 - log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = 
(loss_sum / token_count).item() - val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) - - for p in base_model.parameters(): - p.requires_grad_(True) - base_model.eval() - - log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " - f"elapsed={time.perf_counter() - t0:.1f}s") - return val_loss, val_bpb - - -# --- GPTQ-lite int6 quantization --- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" -def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale - -def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: - """Convert 3D bank tensors into individual 2D tensors with standard names.""" - out: dict[str, Tensor] = {} - n = num_layers - for name, tensor in sd.items(): - if name == "qo_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] - out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] - elif name == "kv_bank": - for i in range(n): - 
out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] - out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] - elif name == "mlp_up_bank": - for i in range(n): - out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] - elif name == "mlp_down_bank": - for i in range(n): - out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] - else: - out[name] = tensor - return out - -def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - """Convert individual 2D tensors back into 3D bank tensors.""" - out: dict[str, Tensor] = {} - n = num_layers - # Reconstruct banks from individual weight keys - qo_slices = [None] * (2 * n) - kv_slices = [None] * (2 * n) - up_slices = [None] * n - down_slices = [None] * n - consumed = set() - for i in range(n): - qk = f"blocks.{i}.attn.c_q.weight" - if qk in sd: - qo_slices[i] = sd[qk] - consumed.add(qk) - ok = f"blocks.{i}.attn.proj.weight" - if ok in sd: - qo_slices[n + i] = sd[ok] - consumed.add(ok) - kk = f"blocks.{i}.attn.c_k.weight" - if kk in sd: - kv_slices[i] = sd[kk] - consumed.add(kk) - vk = f"blocks.{i}.attn.c_v.weight" - if vk in sd: - kv_slices[n + i] = sd[vk] - consumed.add(vk) - fk = f"blocks.{i}.mlp.fc.weight" - if fk in sd: - up_slices[i] = sd[fk] - consumed.add(fk) - dk = f"blocks.{i}.mlp.proj.weight" - if dk in sd: - down_slices[i] = sd[dk] - consumed.add(dk) - out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) - out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) - out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) - out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) - for name, tensor in sd.items(): - if name not in consumed: - out[name] = tensor - return out - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], clip_range: int = 31, - hessians: dict[str, Tensor] | None = None): - num_layers_total = max( - (int(k.split(".")[1]) for 
k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - gptq_count, naive_count = 0, 0 - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - H = hessians.get(name) if hessians else None - if H is not None and t.ndim == 2: - q, s = gptq_quantize_weight(t, H.cpu(), clip_range=clip_range) - gptq_count += 1 - else: - q, s = quantize_int6_per_row(t, clip_range=clip_range) - naive_count += 1 - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - if hessians: - print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers", flush=True) - return result, meta -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * 
float(s.item())).to(orig_dtype) - return out - -# --- Full Hessian GPTQ --- - -def gptq_quantize_weight(W: Tensor, H: Tensor, clip_range: int = 31, - block_size: int = 128, percdamp: float = 0.01) -> tuple[Tensor, Tensor]: - """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023).""" - W_orig = W.float().clone() - rows, cols = W_orig.shape - H = H.float().clone() - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - damp = percdamp * H.diag().mean() - H.diagonal().add_(damp) - perm = torch.argsort(H.diag(), descending=True) - invperm = torch.argsort(perm) - W_perm = W_orig[:, perm].clone() - W_perm[:, dead[perm]] = 0 - H = H[perm][:, perm] - try: - Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) - Hinv = torch.linalg.cholesky(Hinv, upper=True) - except torch.linalg.LinAlgError: - return quantize_int6_per_row(W_orig, clip_range) - best_q, best_scale, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(W_orig.abs(), pct, dim=1) - else: - row_clip = W_orig.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - sf = s.float() - Q = torch.zeros(rows, cols, dtype=torch.int8) - W_work = W_perm.clone() - for i1 in range(0, cols, block_size): - i2 = min(i1 + block_size, cols) - W_block = W_work[:, i1:i2].clone() - Hinv_block = Hinv[i1:i2, i1:i2] - Err = torch.zeros(rows, i2 - i1) - for j in range(i2 - i1): - w_col = W_block[:, j] - d = Hinv_block[j, j] - q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) - Q[:, i1 + j] = q_col.to(torch.int8) - err = (w_col - q_col.float() * sf) / d - Err[:, j] = err - W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) - if i2 < cols: - W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] - recon = Q.float() * sf[:, None] - mse = (W_perm - recon).pow(2).mean().item() - if mse < best_err: - best_q, best_scale, best_err = Q, s, mse - best_q = best_q[:, invperm] - return 
best_q, best_scale - -def _init_hessians(nl: int, dim: int, mlp_dim: int, device: torch.device) -> dict[str, Tensor]: - h: dict[str, Tensor] = {} - for i in range(nl): - for k in ['c_q', 'c_k', 'c_v']: - h[f'blocks.{i}.attn.{k}.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.attn.proj.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.mlp.fc.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) - h[f'blocks.{i}.mlp.proj.weight'] = torch.zeros(mlp_dim, mlp_dim, dtype=torch.float32, device=device) - return h - -def _accum_hessians(hessians: dict[str, Tensor], blocks: nn.ModuleList, dim: int, mlp_dim: int) -> None: - for i, block in enumerate(blocks): - qkv_in = block.attn._gptq_qkv_in.float().reshape(-1, dim) - h_qkv = qkv_in.t() @ qkv_in - hessians[f'blocks.{i}.attn.c_q.weight'] += h_qkv - hessians[f'blocks.{i}.attn.c_k.weight'] += h_qkv - hessians[f'blocks.{i}.attn.c_v.weight'] += h_qkv - o_in = block.attn._gptq_o_in.float().reshape(-1, dim) - hessians[f'blocks.{i}.attn.proj.weight'] += o_in.t() @ o_in - up_in = block.mlp._gptq_up_in.float().reshape(-1, dim) - hessians[f'blocks.{i}.mlp.fc.weight'] += up_in.t() @ up_in - down_in = block.mlp._gptq_down_in.float().reshape(-1, mlp_dim) - hessians[f'blocks.{i}.mlp.proj.weight'] += down_in.t() @ down_in - -def _finalize_hessians(hessians: dict[str, Tensor], num_batches: int) -> None: - for name in hessians: - hessians[name] = hessians[name].cpu() / num_batches - damp = 0.01 * torch.diag(hessians[name]).mean().clamp_min(1e-6) - hessians[name] += damp * torch.eye(hessians[name].shape[0]) - -def gptq_collect_hessians(base_model: nn.Module, train_loader, device: torch.device, - num_batches: int, batch_tokens: int, seq_len: int, - grad_accum_steps: int) -> dict[str, Tensor]: - """Collect Hessians H = X^T X from training data.""" - nl = base_model.num_layers - dim = base_model.tok_emb.weight.shape[1] - mlp_dim = 
base_model.mlp_up_bank.shape[1] - hessians = _init_hessians(nl, dim, mlp_dim, device) - for block in base_model.blocks: - block.attn._save_gptq = True - block.mlp._save_gptq = True - base_model.eval() - with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): - for _ in range(num_batches): - x, y = train_loader.next_batch(batch_tokens, seq_len, grad_accum_steps) - base_model(x, y) - _accum_hessians(hessians, base_model.blocks, dim, mlp_dim) - for block in base_model.blocks: - block.attn._save_gptq = False - block.mlp._save_gptq = False - _finalize_hessians(hessians, num_batches) - base_model.train() - return hessians - -# --- Training --- - -def main() -> None: - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - logfile = None - if master_process: - 
os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - (base_bytes_lut, has_leading_space_lut, is_boundary_token_lut), tokenizer_metadata = load_tokenizer_luts( - args.tokenizer_path, args.tokenizer_meta_path, args.vocab_size, device, - validate_meta=args.tokenizer_meta_validate, - ) - log0(f"tokenizer: kind={tokenizer_metadata.get('tokenizer_kind', 'unknown')} vocab={tokenizer_metadata.get('vocab_size', '?')}") - if tokenizer_metadata.get('tokenizer_kind') == 'sentencepiece': - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - base_model = GPT( - 
vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ve_enabled=args.ve_enabled, - ve_dim=args.ve_dim, - ve_layers=args.ve_layers, - neg_slope=args.negative_slope, - ).to(device).bfloat16() - # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward - base_model.qo_bank.data = base_model.qo_bank.data.float() - base_model.kv_bank.data = base_model.kv_bank.data.float() - base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() - base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, - # and non-bank grads are manually all-reduced before Adam steps. 
- compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model = compiled_model - - # Optimizer split: - # - 4 parameter banks -> Muon (batched Newton-Schulz) - # - token embedding -> Adam - # - scalars/control tensors -> Adam - # - bigram proj, VE proj -> Adam (small matrix params not worth banking) - matrix_params = [ - base_model.qo_bank, base_model.kv_bank, - base_model.mlp_up_bank, base_model.mlp_down_bank, - ] - block_named_params = list(base_model.blocks.named_parameters()) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - scalar_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - scalar_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in 
optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - # Non-bank params that need manual all-reduce (replicated across GPUs) - replicated_params = list(optimizer_tok.param_groups[0]["params"]) - for pg in optimizer_tok.param_groups[1:]: - replicated_params.extend(pg["params"]) - replicated_params.extend(scalar_params) - - optimizer_head = None - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - replicated_params.append(base_model.lm_head.weight) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if optimizer_head is not None: - optimizers.append(optimizer_head) - log0(f"model_params:{sum(p.numel() for p in base_model.parameters())}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - if args.use_gptq and max_wallclock_ms is not None: - max_wallclock_ms -= args.gptq_reserve_ms - log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - # All-reduce all grads for warmup (simple, not optimized) - if distributed: - for p in base_model.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 
1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = 0.997 - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * 
grad_scale).backward() - train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - # === 3-phase overlapped optimizer step === - # Phase 1: Launch async reduce-scatter for banks (biggest first) - optimizer_muon.launch_reduce_scatters() - # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) - if distributed: - for p in replicated_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - optimizer_tok.step() - optimizer_scalar.step() - if optimizer_head is not None: - optimizer_head.step() - # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) - optimizer_muon.step() - zero_grad_all() - # EMA update - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - 
f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - # Apply EMA weights - log0("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - torch.cuda.synchronize() - t_diag = time.perf_counter() - diag_val_loss, diag_val_bpb = eval_val( - args, compiled_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" - ) - export_sd = base_model.state_dict() - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - # Unbank 3D tensors into individual 2D tensors for quantization - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) - # GPTQ calibration: collect Hessians from training data - gptq_hessians = None - if args.use_gptq: - t_gptq = time.perf_counter() - log0(f"gptq:calibrating with {args.gptq_calib_samples} 
batches (training data)...") - calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - gptq_hessians = gptq_collect_hessians( - base_model, calib_loader, device, num_batches=args.gptq_calib_samples, - batch_tokens=args.train_batch_tokens, seq_len=args.train_seq_len, - grad_accum_steps=grad_accum_steps) - del calib_loader - gptq_elapsed = time.perf_counter() - t_gptq - log0(f"gptq:calibrated {len(gptq_hessians)} layers in {gptq_elapsed:.1f}s") - torch.cuda.empty_cache() - quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}, clip_range=args.quant_clip_range, hessians=gptq_hessians) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = lzma.compress(quant_raw, preset=6) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") - log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(lzma.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) - # Re-bank the dequantized tensors - deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - 
xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, ln_scale=args.ln_scale, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - neg_slope=args.negative_slope, - ).to(device).bfloat16() - eval_model.qo_bank.data = eval_model.qo_bank.data.float() - eval_model.kv_bank.data = eval_model.kv_bank.data.float() - eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() - eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - 
log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - # Legal score-first TTT (PR #461 recipe) - if args.ttt_enabled: - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_loss, ttt_bpb = eval_val_sliding_ttt( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, log0=log0, - ) - torch.cuda.synchronize() - log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") - log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Mar 3 2026, 12:15:18) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Tue Mar 31 15:10:52 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A 
| Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 33C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 30C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3F:00.0 Off | 0 | -| N/A 28C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:48:00.0 Off | 0 | -| N/A 35C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | -| N/A 33C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | -| N/A 28C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | -| N/A 28C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | -| N/A 34C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | 
-+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 62800 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 1 N/A N/A 62801 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 2 N/A N/A 62802 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 3 N/A N/A 62803 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 4 N/A N/A 62804 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 5 N/A N/A 62805 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 6 N/A N/A 62806 C ...ameter-golf/.venv/bin/python3 1512MiB | -| 7 N/A N/A 62807 C ...ameter-golf/.venv/bin/python3 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -tokenizer: kind=tokenmonster vocab=998 -train_loader:dataset:fineweb10B_scylla train_shards:194 -val_loader:shards pattern=./data/datasets/fineweb10B_scylla/fineweb_val_*.bin tokens:62363648 -model_params:27022172 -XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -gptq:reserving 9000ms from training budget, effective=591000ms -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 
-warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.8915 val_bpb:3.3918 train_time:0ms step_avg:0.01ms -step:1/9000 train_loss:6.8908 train_time:122ms step_avg:121.65ms -step:2/9000 train_loss:9.2384 train_time:153ms step_avg:76.72ms -step:3/9000 train_loss:8.9442 train_time:238ms step_avg:79.36ms -step:4/9000 train_loss:8.4911 train_time:324ms step_avg:81.04ms -step:5/9000 train_loss:8.0243 train_time:408ms step_avg:81.70ms -step:6/9000 train_loss:7.4500 train_time:494ms step_avg:82.27ms -step:7/9000 train_loss:6.9807 train_time:579ms step_avg:82.65ms -step:8/9000 train_loss:6.7101 train_time:664ms step_avg:83.05ms -step:9/9000 train_loss:6.5964 train_time:749ms step_avg:83.20ms -step:10/9000 train_loss:6.3878 train_time:834ms step_avg:83.39ms -step:500/9000 train_loss:2.3151 train_time:43554ms step_avg:87.11ms -step:1000/9000 train_loss:2.2086 train_time:87306ms step_avg:87.31ms -step:1500/9000 train_loss:2.1805 train_time:131209ms step_avg:87.47ms -step:2000/9000 train_loss:2.1469 train_time:175159ms step_avg:87.58ms -step:2500/9000 train_loss:2.2068 train_time:219040ms step_avg:87.62ms -step:3000/9000 train_loss:2.1005 train_time:262973ms step_avg:87.66ms -step:3500/9000 train_loss:2.0811 train_time:306855ms step_avg:87.67ms -step:4000/9000 train_loss:2.0310 train_time:350773ms step_avg:87.69ms -step:4000/9000 val_loss:2.0749 val_bpb:1.0212 train_time:350830ms step_avg:87.71ms -step:4500/9000 train_loss:2.0054 train_time:394663ms step_avg:87.70ms -step:5000/9000 train_loss:1.9694 train_time:438544ms step_avg:87.71ms -step:5500/9000 train_loss:1.9324 train_time:482435ms step_avg:87.72ms -step:6000/9000 train_loss:1.9929 train_time:526342ms step_avg:87.72ms -swa:start step:6050 -step:6500/9000 train_loss:1.9742 train_time:570896ms step_avg:87.83ms -step:6725/9000 val_loss:1.9610 val_bpb:0.9651 train_time:591069ms step_avg:87.89ms -stopping_early: wallclock_cap 
train_time:591069ms step:6725/9000 -peak memory allocated: 23031 MiB reserved: 23324 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9595 val_bpb:0.9644 eval_time:2111ms -Serialized model: 106201974 bytes -Code size: 101929 bytes -gptq:calibrating with 64 batches (training data)... -gptq:calibrated 66 layers in 6.7s -Serialized model int6+lzma: 15869164 bytes -Total submission size int6+lzma: 15971093 bytes -final_int6_roundtrip val_loss:1.9655 val_bpb:0.9674 eval_time:5303ms -final_int6_roundtrip_exact val_loss:1.96553681 val_bpb:0.96736085 -final_int6_sliding_window val_loss:1.9255 val_bpb:0.9476 stride:64 eval_time:76779ms -final_int6_sliding_window_exact val_loss:1.92548833 val_bpb:0.94764197 -final_int8_zlib_roundtrip_exact val_loss:1.92548833 val_bpb:0.94764197 From 79d57f8e0a80c5d379aaefa687c007536aa350f8 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 26 Apr 2026 00:47:53 -0400 Subject: [PATCH 4/4] Remove non-record Muon TTT submission --- .../README.md | 136 -- .../submission.json | 9 - .../train_gpt.py | 1954 ----------------- .../train_seed1337.log | 275 --- .../train_seed2025.log | 275 --- .../train_seed42.log | 275 --- 6 files changed, 2924 deletions(-) delete mode 100644 records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/README.md delete mode 100644 records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/submission.json delete mode 100644 records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_gpt.py delete mode 100644 records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed1337.log delete mode 100644 records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed2025.log delete mode 100644 records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed42.log diff --git a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/README.md 
b/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/README.md deleted file mode 100644 index 54cabc9c07..0000000000 --- a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/README.md +++ /dev/null @@ -1,136 +0,0 @@ -# 11L Muon TTT + Entropy-Adaptive Epochs - -**val_bpb: 1.1179** (3-seed mean, std 0.0002) | **~15.9 MB** | 8xH100 SXM - -## Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128) - -| Seed | step_avg | steps | Pre-TTT bpb | Post-TTT bpb | TTT gain | TTT time | Artifact (bytes) | -|------|----------|-------|-------------|-------------|----------|----------|----------| -| 1337 | 83.5ms | 7,189 | 1.1214 | **1.11765** | -0.0037 | 477s | 15,944,410 | -| 42 | 83.5ms | 7,177 | 1.1217 | **1.11813** | -0.0035 | 485s | 15,873,826 | -| 2025 | 83.5ms | 7,175 | 1.1217 | **1.11790** | -0.0038 | 479s | 15,879,042 | -| **Mean** | **83.5ms** | **~7,180** | **~1.1216** | **1.1179 (std 0.0002)** | **-0.0037** | **~480s** | | - -## Key Innovation 1: Muon as TTT Optimizer - -Every prior TTT submission uses SGD in the test-time training loop. This submission replaces SGD with Newton-Schulz orthogonalized gradient updates -- the same Muon principle that dominates the training leaderboard, now applied to test-time adaptation. - -### Why this works - -Muon projects each matrix gradient onto a near-semi-orthogonal matrix, normalizing the update's singular values toward 1. For TTT this means: -- No gradient blowup: the update's singular values are normalized to ~1, so a single step cannot inflate the weights -- Better gradient signal: Newton-Schulz whitens the gradient, removing scale correlation between rows that SGD accumulates -- Faster per-epoch convergence: each TTT epoch moves farther in the useful direction - -The result: a 0.0037 bpb TTT gain per seed vs SOTA's 0.0025 (SGD, same 3 epochs), with total TTT time remaining under 600s.
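To see what the Newton-Schulz step does to a raw gradient, here is a minimal NumPy sketch of the quintic iteration (same coefficients and transpose handling as the repo's zeropower_via_newtonschulz5; NumPy in float64 stands in for the bfloat16 Torch version, so the numbers differ slightly):

```python
import numpy as np

def ns5(G: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Quintic Newton-Schulz iteration: drives every singular value of G
    toward 1, i.e. approximately orthogonalizes the gradient."""
    a, b, c = 3.4445, -4.7750, 2.0315  # coefficients from zeropower_via_newtonschulz5
    X = G.astype(np.float64)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T  # iterate on the wide orientation so X @ X.T is small
    X = X / (np.linalg.norm(X) + eps)  # Frobenius norm: all singular values <= 1
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X

rng = np.random.default_rng(0)
G = rng.standard_normal((64, 32))  # stand-in gradient with uneven singular values
s_before = np.linalg.svd(G, compute_uv=False)
s_after = np.linalg.svd(ns5(G), compute_uv=False)
print(s_before.max() / s_before.min())  # large condition number
print(s_after.max() / s_after.min())    # compressed into a narrow band around 1
```

The iteration is only approximately orthogonalizing -- after 5 steps the singular values sit in a band around 1 rather than exactly at 1 -- which is why the submission can trade NS steps (3 vs 5) against eval wall-clock.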
- -### Implementation - -Replaces optimizer.step() in the TTT loop: - - with torch.no_grad(): - for p in ttt_params: - if p.grad is None: - continue - g = p.grad.detach().float() - if g.ndim >= 2: - g = zeropower_via_newtonschulz5(g, steps=3) - p.data.add_(g.to(p.dtype), alpha=-cos_lr) - -zeropower_via_newtonschulz5 is already present in every train_gpt.py. 3 NS steps balance orthogonalization quality vs eval wall-clock (5 steps exceeded the 600s eval budget; 3 steps complete in ~480s). - -## Key Innovation 2: Entropy-Adaptive TTT Epochs - -All prior TTT submissions use a fixed epoch count per chunk. This submission dynamically assigns 2, 3, or 4 TTT epochs per chunk based on the model's measured uncertainty on that content. - -After SCORE phase, the per-chunk NLL is globally synchronized across all DDP ranks (critical: per-rank NLL gives different epoch counts per rank -> different number of dist.all_reduce calls per chunk -> NCCL collective mismatch -> watchdog timeout at 600s). The global NLL gates epoch selection: - - cls_t = torch.tensor(chunk_loss_sum, device=device, dtype=torch.float64) - ctc_t = torch.tensor(chunk_token_count, device=device, dtype=torch.float64) - if world_size > 1: - dist.all_reduce(cls_t, op=dist.ReduceOp.SUM) - dist.all_reduce(ctc_t, op=dist.ReduceOp.SUM) - chunk_nll = (cls_t / ctc_t).item() - - if chunk_nll > 2.1: # hard content (code, math): 4 epochs - effective_epochs = 4 - elif chunk_nll < 1.75: # easy content (prose): 2 epochs - effective_epochs = 2 - else: - effective_epochs = 3 - -This concentrates adaptation budget where it helps most. Average epochs ~3.0; total TTT time unchanged vs fixed-3-epoch baseline. - -## Legal TTT Protocol - -Score-first TTT following PR #461 framework: - -1. Val tokens split into 1,893 non-overlapping 32K-token chunks -2. For each chunk: - - SCORE: Sliding window eval under torch.inference_mode() -- no gradients, no weight mutation - - TRAIN: Muon-style update on already-scored chunk. 
Entropy-adaptive 2/3/4 epochs, cosine LR decay, grad clip 1.0 -3. Last chunk scored but never trained on -4. Chunk N scored by model adapted only on chunks 0..N-1 - -### TTT Hyperparameters - -| Parameter | Value | -|-----------|-------| -| Chunk size | 32,768 tokens | -| Optimizer | Muon-style (Newton-Schulz NS=3 + LR step) | -| Learning rate | 0.002 (cosine decay across chunks) | -| Epochs per chunk | 2/3/4 entropy-adaptive (H_HIGH=2.1, H_LOW=1.75 nats) | -| Frozen blocks | None (all blocks adapt) | -| Gradient clip | 1.0 | - -## Training Architecture - -Full SOTA stack from PR #399 and PR #414: - -| Component | Setting | -|-----------|---------| -| Layers | 11 (512d, 8H, 4KV) | -| MLP | 3x with LeakyReLU(0.5)^2 | -| BigramHash | 1536 | -| XSA | Last 4 layers | -| RoPE | Partial (16/64 dims) | -| LN Scale | 1/sqrt(layer+1) | -| VE128 | Layers 9-10 | -| Weight avg | EMA(0.997) + Tight SWA(every 50) | -| Quantization | GPTQ-lite int6 + lzma (preset=7) | -| Optimizer | Parameter Banking + Parallel Muon | - -## Run Command - - NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=1536 XSA_LAST_N=4 - EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 - ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 - VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 - TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 - TTT_FREEZE_BLOCKS=0 TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 - TTT_MUON=1 TTT_NS_STEPS=3 TTT_ENTROPY_ADAPT=1 - TTT_ENTROPY_HIGH=2.1 TTT_ENTROPY_LOW=1.75 - MUON_WD=0.04 ADAM_WD=0.04 - MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 - MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 - MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 - ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=599 EVAL_STRIDE=64 - SEED=1337 - torchrun --standalone --nproc_per_node=8 train_gpt.py - -## Timing Budget - -| Phase | Time | -|-------|------| -| Training | ~600s (<=10 min) | -| Standard eval (int6 roundtrip + sliding window) | ~82s | -| Legal TTT (score-first + adaptation) | ~480s 
| -| Total eval | ~562s (< 10 min) | - -## Credits - -- LeakyReLU^2 activation: PR #493 by @parinzee, PR #518 by @sofiabod -- Optimizer (Parameter Banking + Parallel Muon): PR #399 by @abaybektursun -- TTT recipe (score-first framework): PR #461 by @Christopher-Lee-McClendon -- Base architecture: PR #414 by @signalrush -- SOTA base adapted from: @abaybektursun (val_bpb 1.1194) diff --git a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/submission.json b/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/submission.json deleted file mode 100644 index 31c236ed48..0000000000 --- a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/submission.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "name": "Record: 11L Muon TTT + Entropy-Adaptive Epochs", - "val_bpb": 1.1179, - "bytes_total": 15944410, - "blurb": "Two novel TTT innovations on SOTA base (PR #399 + PR #414 + PR #461 stack): (1) Muon-style Newton-Schulz orthogonalized gradient updates replace SGD in the TTT loop (NS=3 steps per matrix parameter per epoch) -- more stable and effective test-time adaptation; (2) Entropy-adaptive epoch selection -- per-chunk NLL (globally synced across DDP ranks) gates between 2/3/4 TTT epochs, concentrating adaptation budget on harder content. 3-seed mean: 1.1179 (std 0.0002). 
All artifacts under 16MB, all eval under 10 min.", - "author": "Aamod Bhatt", - "github_id": "aamodbhatt", - "date": "2026-03-28" -} \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_gpt.py b/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_gpt.py deleted file mode 100644 index 4055f904f3..0000000000 --- a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_gpt.py +++ /dev/null @@ -1,1954 +0,0 @@ -from __future__ import annotations -import copy -import glob -import io -import lzma -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func as flash_attn_3_func -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - 
train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - 
lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) - lawa_k = int(os.environ.get("LAWA_K", 10)) - lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) - muon_wd = float(os.environ.get("MUON_WD", 0.04)) - adam_wd = float(os.environ.get("ADAM_WD", 0.04)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - rope_dims = int(os.environ.get("ROPE_DIMS", 16)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) - dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) - late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) - ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) - ve_dim = int(os.environ.get("VE_DIM", 128)) - ve_layers = os.environ.get("VE_LAYERS", "9,10") - gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) - value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) - ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) - ttt_lr = float(os.environ.get("TTT_LR", 0.002)) - ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) - ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) - ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) - ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) - ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) - ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) - ttt_muon = bool(int(os.environ.get("TTT_MUON", "1"))) # use Muon-style NS update for TTT - ttt_ns_steps = int(os.environ.get("TTT_NS_STEPS", "3")) # Newton-Schulz steps for TTT Muon - ttt_entropy_adapt = bool(int(os.environ.get("TTT_ENTROPY_ADAPT", "1"))) # entropy-adaptive epochs - ttt_entropy_high = float(os.environ.get("TTT_ENTROPY_HIGH", "2.1")) # nats – hard chunk threshold - ttt_entropy_low = float(os.environ.get("TTT_ENTROPY_LOW", "1.75")) # nats – easy chunk threshold - 
-# --- Batched Newton-Schulz orthogonalization --- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: - """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" - a, b, c = (3.4445, -4.7750, 2.0315) - was_2d = G.ndim == 2 - if was_2d: - G = G.unsqueeze(0) - X = G.bfloat16() - transposed = X.size(-2) > X.size(-1) - if transposed: - X = X.mT - X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) - for _ in range(steps): - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - if transposed: - X = X.mT - if was_2d: - X = X.squeeze(0) - return X - -# --- Parallel Muon optimizer --- - -class Muon(torch.optim.Optimizer): - """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. - - No DDP for bank params. After backward, this optimizer: - 1. Launches async reduce-scatter for all banks (biggest first) - 2. Returns control so Adam can step on small params while RS is in-flight - 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather - 4. 
Each all-gather overlaps with next bank's NS5 - """ - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - self._built = False - - def _build(self): - self._distributed = dist.is_available() and dist.is_initialized() - self._world_size = dist.get_world_size() if self._distributed else 1 - self._rank = dist.get_rank() if self._distributed else 0 - ws = self._world_size - - self._bank_meta = [] - for group in self.param_groups: - for p in group["params"]: - B = p.shape[0] - padded_B = ((B + ws - 1) // ws) * ws - shard_B = padded_B // ws - tail = p.shape[1:] - dev = p.device - self._bank_meta.append({ - 'p': p, - 'B': B, - 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, - }) - # Sort by size descending -- launch biggest reduce-scatters first - self._bank_meta.sort(key=lambda m: -m['p'].numel()) - self._built = True - - def launch_reduce_scatters(self): - """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" - if not self._built: - self._build() - if not self._distributed: - return - self._rs_futures = [] - for m in self._bank_meta: - p = m['p'] - if p.grad is None: - self._rs_futures.append(None) - continue - pg = m['padded_grad'] - pg[:m['B']].copy_(p.grad.bfloat16()) - if pg.shape[0] > m['B']: - pg[m['B']:].zero_() - fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) - self._rs_futures.append(fut) - - @torch.no_grad() - def step(self, closure=None): - """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - if not self._built: - self._build() - - for group in self.param_groups: - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group.get("weight_decay", 0.0) - - prev_ag_handle = None - prev_m = None - - sharded = self._distributed and hasattr(self, '_rs_futures') - - for i, m in enumerate(self._bank_meta): - p = m['p'] - if p.grad is None: - continue - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if sharded and self._rs_futures[i] is not None: - self._rs_futures[i].wait() - g = m['shard'] - buf = m['shard_mom'] - else: - g = p.grad.bfloat16() - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - - buf.mul_(momentum).add_(g) - if nesterov: - update = g.add(buf, alpha=momentum) - else: - update = buf - - update = zeropower_via_newtonschulz5(update, steps=backend_steps) - - if sharded: - prev_ag_handle = dist.all_gather_into_tensor( - m['full_update'], update, async_op=True) - prev_m = m - else: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if hasattr(self, '_rs_futures'): - del self._rs_futures - - return loss - -# --- Tokenizer evaluation helpers --- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] -def eval_val( - args: Hyperparameters, - model: nn.Module, - 
rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += 
token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# --- Quantization helpers --- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = 
torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - 
"quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - -# --- Data loading --- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - def next_batch(self, 
global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# --- Transformer modules --- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - self.rope_dims = rope_dims if rope_dims > 0 else dim - inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) - 
self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: - if rope_dims > 0 and rope_dims < x.size(-1): - x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] - half = rope_dims // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rope, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - 
self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - # No CastedLinear -- weights come from banks - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = 0 # set by GPT.__init__ for partial RoPE - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - # Gated attention and value residual (non-banked small params) - self.gated_attention = gated_attention - if gated_attention: - self.attn_gate = nn.Linear(dim, num_heads, bias=True) - nn.init.zeros_(self.attn_gate.weight) - nn.init.constant_(self.attn_gate.bias, 4.0) - self.value_residual = value_residual - if value_residual: - self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: - bsz, seqlen, dim = x.shape - q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = F.linear(x, v_w.to(x.dtype)) - if v_embed is not None: - v = v + v_embed - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - raw_v = v if self.value_residual else None - if self.value_residual and v0 is not None: - lam = self.vr_lambda.to(dtype=v.dtype) - v = lam[0] * v0 + lam[1] * v - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin, self.rope_dims) - k = apply_rotary_emb(k, cos, sin, self.rope_dims) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - if self.gated_attention: - # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout - gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) - y = y * gate - y = y.reshape(bsz, seqlen, dim) - return F.linear(y, out_w.to(x.dtype)), raw_v - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - 
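The `_xsa_efficient` docstring above claims the GQA-aware reshape avoids `repeat_interleave` without changing the result. A sketch verifying that equivalence on random tensors (both helper functions are mine, written to mirror the code above):

```python
import torch
import torch.nn.functional as F

def xsa_naive(y, v):
    # Naive XSA: expand v to all H query heads, then subtract the projection
    # of y onto the normalized value direction.
    vn = F.normalize(v, dim=-1).repeat_interleave(y.size(-2) // v.size(-2), dim=-2)
    return y - (y * vn).sum(dim=-1, keepdim=True) * vn

def xsa_efficient(y, v):
    # GQA-aware XSA as in _xsa_efficient above: reshape into KV groups so the
    # normalized value broadcasts, no materialized repeat.
    B, T, H, D = y.shape
    Hkv = v.size(-2)
    y_g = y.reshape(B, T, Hkv, H // Hkv, D)
    vn = F.normalize(v, dim=-1).unsqueeze(-2)
    proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
    return (y_g - proj).reshape(B, T, H, D)

y = torch.randn(2, 5, 8, 16)   # H = 8 query heads
v = torch.randn(2, 5, 2, 16)   # Hkv = 2 KV heads, group size 4
```

`repeat_interleave` maps KV head `j` to query heads `[4j, 4j+3]`, which is exactly the grouping the reshape produces, so the two paths agree elementwise.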
return (1 - g) * x + g * x_prev - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class ValueEmbedding(nn.Module): - """Reinject token identity into attention values at specific layers. 
- Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" - def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): - super().__init__() - self.embed = nn.Embedding(vocab_size, ve_dim) - nn.init.normal_(self.embed.weight, std=0.01) - self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(token_ids) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - # No CastedLinear -- weights come from banks - def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: - x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) - return F.linear(x.square(), down_w.to(x.dtype)) - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - layer_idx: int = 0, - ln_scale: bool = False, - dtg: bool = False, - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - gated_attention=gated_attention, value_residual=value_residual) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - if dtg: - self.dtg_gate = nn.Linear(dim, 1, bias=True) - nn.init.zeros_(self.dtg_gate.weight) - nn.init.constant_(self.dtg_gate.bias, 
2.0) - else: - self.dtg_gate = None - def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: - mix = self.resid_mix.to(dtype=x.dtype) - x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) - x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out - x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) - if self.dtg_gate is not None: - gate = torch.sigmoid(self.dtg_gate(x_in.detach())) - x_out = x_in + gate * (x_out - x_in) - return x_out, raw_v - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - dtg: bool = False, - ve_enabled: bool = False, - ve_dim: int = 128, - ve_layers: str = "9,10", - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.value_residual = value_residual - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - 
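The MLP above applies `leaky_relu(negative_slope=0.5)` and then squares, giving a ReLU²-like activation that keeps a damped quadratic response on the negative side instead of zeroing it. A quick numeric illustration (the helper name is mine):

```python
import torch
import torch.nn.functional as F

def sq_leaky_relu(x, slope=0.5):
    # Activation used by MLP.forward above: LeakyReLU followed by squaring.
    # For x >= 0 this is x^2; for x < 0 it is (slope * x)^2 = x^2 / 4.
    return F.leaky_relu(x, negative_slope=slope).square()

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
out = sq_leaky_relu(x)
```

Note the squared output is non-negative everywhere, so the down projection alone determines the sign of each MLP feature's contribution.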
self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - # Parameter banks: contiguous 3D tensors for batched optimizer - head_dim = model_dim // num_heads - kv_dim = num_kv_heads * head_dim - mlp_dim = int(mlp_mult * model_dim) - self.num_layers = num_layers - self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) - self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) - self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) - self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - layer_idx=i, - ln_scale=ln_scale, - dtg=dtg, - gated_attention=gated_attention, - value_residual=value_residual, - ) - for i in range(num_layers) - ] - ) - if rope_dims > 0: - head_dim = model_dim // num_heads - for block in self.blocks: - block.attn.rope_dims = rope_dims - block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] - kv_dim_ve = self._ve_target_dim - if self.ve_layer_indices: - self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) - self.ve_layer_scales = nn.ParameterList( - [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] - ) - else: - self.ve_shared = None - self.ve_layer_scales = nn.ParameterList() - self.value_embeds = nn.ModuleList() # keep empty for compat - self.final_norm = RMSNorm() - 
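The parameter banks above pack per-layer matrices into contiguous 3D tensors so the optimizer can batch over layers: `qo_bank[i]` holds layer `i`'s Q weight and `qo_bank[n + i]` its output projection, and `kv_bank` packs K/V the same way. A layout sketch with assumed toy shapes (values chosen so each slice is identifiable):

```python
import torch

n, model_dim = 3, 8  # assumed toy sizes, not the script's real config
# Fill slot j of the bank with the constant j so slices are identifiable.
qo_bank = torch.arange(2 * n, dtype=torch.float32).view(2 * n, 1, 1).expand(2 * n, model_dim, model_dim)

def layer_weights(bank, i, n):
    # First half of the bank = one role (e.g. Q), second half = the other (e.g. out_proj).
    return bank[i], bank[n + i]

q_w, out_w = layer_weights(qo_bank, 1, n)  # layer 1: slots 1 and n + 1 = 4
```

This is why `_unbank_state_dict` further down can recover `blocks.{i}.attn.c_q.weight` from `tensor[i]` and `blocks.{i}.attn.proj.weight` from `tensor[n + i]`.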
self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - n = self.num_layers - proj_scale = 1.0 / math.sqrt(2 * n) - # Init banks: orthogonal, with proj layers scaled down and out/down zero-init - for i in range(n): - nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q - nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) - nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K - nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V - nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up - nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) - # Scale proj layers (out_proj and mlp_down are "proj" layers) - self.qo_bank.data[n + i].mul_(proj_scale) - self.mlp_down_bank.data[i].mul_(proj_scale) - # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: - """Get value embedding for a specific layer using shared table + per-layer scale.""" - if self.ve_shared is None or layer_idx not in self.ve_layer_indices: - return None - if ve_cache is not None and 've' not in 
ve_cache: - ve_cache['ve'] = self.ve_shared(input_ids) - ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) - ve_idx = self.ve_layer_indices.index(layer_idx) - return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - v0 = None - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x, raw_v = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve, v0=v0) - if v0 is None and raw_v is not None: - v0 = raw_v - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x, _ = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve, v0=v0) - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - 
(k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - return main_loss - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - v0 = None - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x, raw_v = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve, v0=v0) - if v0 is None and raw_v is not None: - v0 = raw_v - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x, _ = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve, v0=v0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - -# --- Sliding window evaluation --- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - 
val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt 
= y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -def eval_val_sliding_ttt( - args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, - device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - stride: int, batch_seqs: int = 32, log0=print, -) -> tuple[float, float]: - """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, - then train on it. 
Every token scored BEFORE any update that could use it.""" - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - ttt_chunk = args.ttt_chunk_tokens - - # Pre-compute all window starts - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - - # Assign each window to a chunk based on the first token it scores - num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk - chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] - for ws in window_starts: - end = min(ws + seq_len, total_tokens) - wlen = end - ws - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_start = ws + s - ci = min(scored_start // ttt_chunk, num_chunks - 1) - chunk_windows[ci].append(ws) - - log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " - f"total_windows={len(window_starts)} stride={stride} " - f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " - f"freeze_blocks={args.ttt_freeze_blocks}") - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Freeze first N blocks - frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) - ttt_params = [] - for name, p in base_model.named_parameters(): - freeze = False - for bi in frozen_block_ids: - if f"blocks.{bi}." 
in name: - freeze = True - break - if freeze: - p.requires_grad_(False) - else: - p.requires_grad_(True) - ttt_params.append(p) - - log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " - f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") - - # Muon TTT: no external optimizer; we apply Newton-Schulz orthogonalized updates - # For non-matrix params (1-D) we use plain LR-scaled gradient updates (like AdamW scalar track) - use_muon_ttt = args.ttt_muon - if not use_muon_ttt: - optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) - else: - optimizer = None # manual update below - t0 = time.perf_counter() - - for ci in range(num_chunks): - windows = chunk_windows[ci] - if not windows: - continue - chunk_start = ci * ttt_chunk - chunk_end = min((ci + 1) * ttt_chunk, total_tokens) - - # Track per-chunk NLL for entropy-adaptive epoch selection - chunk_loss_sum = 0.0 - chunk_token_count = 0 - # --- Phase 1: SCORE this chunk's windows (inference_mode) --- - my_s = (len(windows) * rank) // world_size - my_e = (len(windows) * (rank + 1)) // world_size - my_windows = windows[my_s:my_e] - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk_tok[:-1] - y_batch[i, :wlen] = chunk_tok[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), reduction="none", - 
).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - chunk_loss_sum += scored_nll.sum().item() - chunk_token_count += float(wlen - s) - tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - # --- Phase 2: TRAIN on this chunk (already scored = legal) --- - is_last_chunk = (ci == num_chunks - 1) - if not is_last_chunk and args.ttt_epochs > 0: - base_model.train() - chunk_seqs = (chunk_end - chunk_start) // seq_len - if chunk_seqs > 0: - cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) - if not use_muon_ttt: - for pg in optimizer.param_groups: - pg['lr'] = cos_lr - my_seq_s = (chunk_seqs * rank) // world_size - my_seq_e = (chunk_seqs * (rank + 1)) // world_size - my_chunk_seqs = my_seq_e - my_seq_s - # --- Entropy-adaptive epoch count (synchronized across ranks) --- - # Sync chunk NLL across all ranks so every rank uses the same effective_epochs - chunk_nll = float('inf') - if args.ttt_entropy_adapt: - cls_t = torch.tensor(chunk_loss_sum, device=device, dtype=torch.float64) - ctc_t = torch.tensor(chunk_token_count, device=device, dtype=torch.float64) - if world_size > 1: - dist.all_reduce(cls_t, op=dist.ReduceOp.SUM) - dist.all_reduce(ctc_t, op=dist.ReduceOp.SUM) - if ctc_t.item() > 0: - chunk_nll = (cls_t / ctc_t).item() - effective_epochs = args.ttt_epochs - if args.ttt_entropy_adapt: - if chunk_nll > args.ttt_entropy_high: - effective_epochs = args.ttt_epochs + 1 # hard chunk → extra epoch - elif chunk_nll < args.ttt_entropy_low: - effective_epochs = max(args.ttt_epochs - 1, 1) # easy chunk → save time - # Wall-clock guard: if we're past 85% of soft eval cap, cap at baseline epochs - 
elapsed_now = time.perf_counter() - t0 - eval_soft_cap = float(os.environ.get("EVAL_TIME_SOFT_CAP_SECONDS", "570")) - if elapsed_now > eval_soft_cap * 0.85: - effective_epochs = min(effective_epochs, args.ttt_epochs) - - for _ep in range(effective_epochs): - for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): - be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) - actual_bs = my_seq_s + bs - start_tok = chunk_start + actual_bs * seq_len - end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 - if end_tok > val_tokens.numel(): - continue - local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - if not use_muon_ttt: - optimizer.zero_grad(set_to_none=True) - else: - for p in ttt_params: - p.grad = None - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = base_model(x, y) - loss.backward() - if world_size > 1: - for p in ttt_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) - if not use_muon_ttt: - optimizer.step() - else: - # Muon-style update: Newton-Schulz orthogonalization for matrix params, - # plain LR-scaled gradient for vector params (norms, biases, scalars). 
- with torch.no_grad(): - for p in ttt_params: - if p.grad is None: - continue - g = p.grad.detach().float() - if g.ndim >= 2: - g = zeropower_via_newtonschulz5(g, steps=args.ttt_ns_steps) - p.data.add_(g.to(p.dtype), alpha=-cos_lr) - - if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): - elapsed = time.perf_counter() - t0 - rl = loss_sum.item() / max(token_count.item(), 1) - rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 - log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) - - for p in base_model.parameters(): - p.requires_grad_(True) - base_model.eval() - - log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " - f"elapsed={time.perf_counter() - t0:.1f}s") - return val_loss, val_bpb - - -# --- GPTQ-lite int6 quantization --- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." 
not in name): - return "attn" - return "other" -def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale - -def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: - """Convert 3D bank tensors into individual 2D tensors with standard names.""" - out: dict[str, Tensor] = {} - n = num_layers - for name, tensor in sd.items(): - if name == "qo_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] - out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] - elif name == "kv_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] - out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] - elif name == "mlp_up_bank": - for i in range(n): - out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] - elif name == "mlp_down_bank": - for i in range(n): - out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] - else: - out[name] = tensor - return out - -def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - """Convert individual 2D tensors back into 3D bank tensors.""" - out: dict[str, Tensor] = {} - n = num_layers - # Reconstruct banks 
from individual weight keys - qo_slices = [None] * (2 * n) - kv_slices = [None] * (2 * n) - up_slices = [None] * n - down_slices = [None] * n - consumed = set() - for i in range(n): - qk = f"blocks.{i}.attn.c_q.weight" - if qk in sd: - qo_slices[i] = sd[qk] - consumed.add(qk) - ok = f"blocks.{i}.attn.proj.weight" - if ok in sd: - qo_slices[n + i] = sd[ok] - consumed.add(ok) - kk = f"blocks.{i}.attn.c_k.weight" - if kk in sd: - kv_slices[i] = sd[kk] - consumed.add(kk) - vk = f"blocks.{i}.attn.c_v.weight" - if vk in sd: - kv_slices[n + i] = sd[vk] - consumed.add(vk) - fk = f"blocks.{i}.mlp.fc.weight" - if fk in sd: - up_slices[i] = sd[fk] - consumed.add(fk) - dk = f"blocks.{i}.mlp.proj.weight" - if dk in sd: - down_slices[i] = sd[dk] - consumed.add(dk) - out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) - out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) - out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) - out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) - for name, tensor in sd.items(): - if name not in consumed: - out[name] = tensor - return out - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - q, s = 
quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - -# --- Training --- - -def main() -> None: - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - 
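The per-row int6 scheme in `quantize_int6_per_row` above bounds reconstruction error by half a quantization step per row. A minimal round-trip sketch of the amax-clipping case (the percentile search is omitted; helper name is mine):

```python
import torch

def int6_roundtrip(t: torch.Tensor, clip: int = 31):
    # Per-row symmetric quantize/dequantize as in quantize_int6_per_row above,
    # using only the pct == 1.0 (amax) branch of the clip search.
    row_clip = t.abs().amax(dim=1)
    s = (row_clip / clip).clamp_min(1.0 / clip)
    q = torch.clamp(torch.round(t / s[:, None]), -clip, clip).to(torch.int8)
    return q.float() * s[:, None]

t = torch.randn(16, 64)
recon = int6_roundtrip(t)
s = (t.abs().amax(dim=1) / 31).clamp_min(1.0 / 31)
```

Since every row value satisfies `|t / s| <= 31`, rounding is the only error source, so `|recon - t| <= s / 2` holds row-wise.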
torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - CastedLinear._qat_enabled = args.qat_enabled - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, - ve_dim=args.ve_dim, - ve_layers=args.ve_layers, - gated_attention=args.gated_attention, - value_residual=args.value_residual, - ).to(device).bfloat16() - # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward - base_model.qo_bank.data = base_model.qo_bank.data.float() - base_model.kv_bank.data = base_model.kv_bank.data.float() - base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() - base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, - # and non-bank grads are manually all-reduced before Adam steps. 
- compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model = compiled_model - - # Optimizer split: - # - 4 parameter banks -> Muon (batched Newton-Schulz) - # - token embedding -> Adam - # - scalars/control tensors -> Adam - # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) - matrix_params = [ - base_model.qo_bank, base_model.kv_bank, - base_model.mlp_up_bank, base_model.mlp_down_bank, - ] - block_named_params = list(base_model.blocks.named_parameters()) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - scalar_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - scalar_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in 
optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - # Non-bank params that need manual all-reduce (replicated across GPUs) - replicated_params = list(optimizer_tok.param_groups[0]["params"]) - for pg in optimizer_tok.param_groups[1:]: - replicated_params.extend(pg["params"]) - replicated_params.extend(scalar_params) - - optimizer_head = None - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - replicated_params.append(base_model.lm_head.weight) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if optimizer_head is not None: - optimizers.append(optimizer_head) - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - 
f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - # All-reduce all grads for warmup (simple, not optimized) - if distributed: - for p in base_model.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - 
base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - from collections import deque - lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = 0.997 - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step 
in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - # === 3-phase overlapped optimizer step === - # Phase 1: Launch async reduce-scatter for banks (biggest first) - optimizer_muon.launch_reduce_scatters() - # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) - if distributed: - for p in replicated_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - optimizer_tok.step() - optimizer_scalar.step() - if optimizer_head is not None: - optimizer_head.step() - # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) - optimizer_muon.step() - zero_grad_all() - # EMA update - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - 
if args.lawa_enabled and step % args.lawa_freq == 0: - lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - # Apply weight averaging - if args.lawa_enabled and len(lawa_queue) > 1: - log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") - current_state = base_model.state_dict() - avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} - for snap in lawa_queue: - for name in avg_state: - avg_state[name] += snap[name].float() - for name in avg_state: - avg_state[name] /= len(lawa_queue) - avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) - base_model.load_state_dict(avg_state, strict=True) - else: - log0("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - torch.cuda.synchronize() - t_diag = time.perf_counter() - diag_val_loss, diag_val_bpb = eval_val( - args, compiled_model, rank, world_size, 
device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" - ) - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - # Unbank 3D tensors into individual 2D tensors for quantization - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) - quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = lzma.compress(quant_raw, preset=7) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") - log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(lzma.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) - # Re-bank the dequantized tensors - deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) - 
eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - gated_attention=args.gated_attention, value_residual=args.value_residual, - ).to(device).bfloat16() - eval_model.qo_bank.data = eval_model.qo_bank.data.float() - eval_model.kv_bank.data = eval_model.kv_bank.data.float() - eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() - eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, 
sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - # Legal score-first TTT (PR #461 recipe) - if args.ttt_enabled: - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_loss, ttt_bpb = eval_val_sliding_ttt( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, log0=log0, - ) - torch.cuda.synchronize() - log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") - log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": - main() 
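
The script above calls `quantize_int6_per_row` and expects `dequantize_mixed_int6` to undo it with one scale per row broadcast over the remaining dimensions (`scale.view(q.shape[0], 1, …)`), but the quantizer itself is defined outside this hunk. As a reading aid, here is a hypothetical minimal sketch of a symmetric per-row int6 quantizer consistent with that dequantize path — the function names mirror the script, but the implementation is an assumption, not the submitted code:

```python
# Hypothetical per-row symmetric int6 quantizer matching the dequantize
# path in dequantize_mixed_int6: one FP16 scale per row (first dim),
# values in the signed 6-bit range [-31, 31]. The real script may bit-pack
# the 6-bit values before lzma; here they are stored in int8 for clarity.
import torch
from torch import Tensor

QMAX = 31  # symmetric int6 range: [-31, 31]

def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]:
    flat = t.reshape(t.shape[0], -1).float()
    # Per-row max-abs scale; clamp avoids division by zero on all-zero rows.
    scale = flat.abs().amax(dim=1).clamp_min(1e-12) / QMAX
    q = torch.round(flat / scale[:, None]).clamp_(-QMAX, QMAX)
    return q.to(torch.int8).reshape(t.shape), scale.to(torch.float16)

def dequantize_int6_per_row(q: Tensor, scale: Tensor) -> Tensor:
    # Broadcast the per-row scale over all trailing dims, as in
    # dequantize_mixed_int6's scale.view(q.shape[0], 1, ...).
    view = (q.shape[0],) + (1,) * (q.ndim - 1)
    return q.float() * scale.float().view(view)

if __name__ == "__main__":
    w = torch.randn(8, 64)
    q, s = quantize_int6_per_row(w)
    w_hat = dequantize_int6_per_row(q, s)
    # Round-to-nearest bounds the per-element error by half a step per row.
    err = (w - w_hat).abs().amax(dim=1)
    assert torch.all(err <= s.float() * 0.5 + 1e-6)
```

Per-row scales keep each bank row independently normalized, which is why the round-trip eval in the script reloads the `.q`/`.scale` pairs and re-banks them before measuring `final_int6_roundtrip` bpb.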
diff --git a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed1337.log b/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed1337.log deleted file mode 100644 index 1c9e2194ae..0000000000 --- a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed1337.log +++ /dev/null @@ -1,275 +0,0 @@ -W0327 23:47:52.820000 59557 torch/distributed/run.py:803] -W0327 23:47:52.820000 59557 torch/distributed/run.py:803] ***************************************** -W0327 23:47:52.820000 59557 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0327 23:47:52.820000 59557 torch/distributed/run.py:803] ***************************************** -logs/run1b_muon_ttt_ns3_seed1337.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26928220 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 
-warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.9304 val_bpb:4.1046 train_time:0ms step_avg:0.01ms -step:1/9000 train_loss:6.9322 train_time:132ms step_avg:131.72ms -step:2/9000 train_loss:8.6545 train_time:159ms step_avg:79.37ms -step:3/9000 train_loss:7.6926 train_time:239ms step_avg:79.61ms -step:4/9000 train_loss:7.2518 train_time:320ms step_avg:79.90ms -step:5/9000 train_loss:7.1705 train_time:401ms step_avg:80.10ms -step:6/9000 train_loss:7.1160 train_time:481ms step_avg:80.20ms -step:7/9000 train_loss:7.0271 train_time:565ms step_avg:80.77ms -step:8/9000 train_loss:6.9597 train_time:648ms step_avg:81.02ms -step:9/9000 train_loss:6.5744 train_time:730ms step_avg:81.12ms -step:10/9000 train_loss:6.2001 train_time:814ms step_avg:81.36ms -step:500/9000 train_loss:2.3935 train_time:41426ms step_avg:82.85ms -step:1000/9000 train_loss:2.2635 train_time:83144ms step_avg:83.14ms -step:1500/9000 train_loss:2.2073 train_time:125004ms step_avg:83.34ms -step:2000/9000 train_loss:2.0563 train_time:166782ms step_avg:83.39ms -step:2500/9000 train_loss:2.1574 train_time:208519ms step_avg:83.41ms -step:3000/9000 train_loss:2.1500 train_time:250231ms step_avg:83.41ms -step:3500/9000 train_loss:2.1676 train_time:291931ms step_avg:83.41ms -step:4000/9000 train_loss:1.9646 train_time:333604ms step_avg:83.40ms -step:4000/9000 val_loss:2.0567 val_bpb:1.2181 train_time:333654ms step_avg:83.41ms -step:4500/9000 train_loss:2.1166 train_time:375269ms step_avg:83.39ms -step:5000/9000 train_loss:2.0992 train_time:416920ms step_avg:83.38ms -step:5500/9000 train_loss:2.0144 train_time:458562ms step_avg:83.37ms -step:6000/9000 train_loss:1.9350 train_time:500192ms step_avg:83.37ms -swa:start step:6500 -step:6500/9000 train_loss:2.0830 train_time:541808ms step_avg:83.36ms -late_qat:enabled step:6670 scale:0.1498 -step:7000/9000 train_loss:1.7905 train_time:584063ms step_avg:83.44ms -step:7189/9000 val_loss:1.9209 val_bpb:1.1376 train_time:600068ms step_avg:83.47ms 
-stopping_early: wallclock_cap train_time:600068ms step:7189/9000 -peak memory allocated: 21471 MiB reserved: 22002 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9191 val_bpb:1.1366 eval_time:1973ms -Serialized model: 106027446 bytes -Code size: 93038 bytes -Serialized model int6+lzma: 15851372 bytes -Total submission size int6+lzma: 15944410 bytes -final_int6_roundtrip val_loss:1.9331 val_bpb:1.1449 eval_time:6284ms -final_int6_roundtrip_exact val_loss:1.93305924 val_bpb:1.14486657 -final_int6_sliding_window val_loss:1.8934 val_bpb:1.1214 stride:64 eval_time:74154ms -final_int6_sliding_window_exact val_loss:1.89337811 val_bpb:1.12136814 -final_int8_zlib_roundtrip_exact val_loss:1.89337811 val_bpb:1.12136814 -ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 -ttt_sliding:params unfrozen=26928220 frozen=0 - ttt_chunk [1/1893] bpb=1.159361 time=0.5s - ttt_chunk [11/1893] bpb=1.146201 time=2.9s - ttt_chunk [21/1893] bpb=1.131428 time=5.4s - ttt_chunk [31/1893] bpb=1.128968 time=8.0s - ttt_chunk [41/1893] bpb=1.114804 time=10.5s - ttt_chunk [51/1893] bpb=1.109042 time=13.0s - ttt_chunk [61/1893] bpb=1.115817 time=15.6s - ttt_chunk [71/1893] bpb=1.114274 time=18.1s - ttt_chunk [81/1893] bpb=1.113383 time=20.5s - ttt_chunk [91/1893] bpb=1.114187 time=23.0s - ttt_chunk [101/1893] bpb=1.117811 time=25.5s - ttt_chunk [111/1893] bpb=1.120252 time=28.0s - ttt_chunk [121/1893] bpb=1.113678 time=30.4s - ttt_chunk [131/1893] bpb=1.113996 time=32.9s - ttt_chunk [141/1893] bpb=1.119631 time=35.4s - ttt_chunk [151/1893] bpb=1.121440 time=37.9s - ttt_chunk [161/1893] bpb=1.120931 time=40.5s - ttt_chunk [171/1893] bpb=1.125218 time=43.0s - ttt_chunk [181/1893] bpb=1.127498 time=45.5s - ttt_chunk [191/1893] bpb=1.134713 time=48.0s - ttt_chunk [201/1893] bpb=1.133455 time=50.6s - ttt_chunk [211/1893] bpb=1.131281 time=53.1s - ttt_chunk [221/1893] bpb=1.132771 time=55.6s - ttt_chunk [231/1893] 
bpb=1.131464 time=58.1s - ttt_chunk [241/1893] bpb=1.131787 time=60.7s - ttt_chunk [251/1893] bpb=1.131297 time=63.2s - ttt_chunk [261/1893] bpb=1.128443 time=65.6s - ttt_chunk [271/1893] bpb=1.127250 time=68.1s - ttt_chunk [281/1893] bpb=1.128627 time=70.7s - ttt_chunk [291/1893] bpb=1.130481 time=73.2s - ttt_chunk [301/1893] bpb=1.131184 time=75.7s - ttt_chunk [311/1893] bpb=1.133302 time=78.2s - ttt_chunk [321/1893] bpb=1.135227 time=80.8s - ttt_chunk [331/1893] bpb=1.135063 time=83.3s - ttt_chunk [341/1893] bpb=1.134023 time=85.8s - ttt_chunk [351/1893] bpb=1.136300 time=88.4s - ttt_chunk [361/1893] bpb=1.136477 time=90.8s - ttt_chunk [371/1893] bpb=1.135770 time=93.3s - ttt_chunk [381/1893] bpb=1.136014 time=95.9s - ttt_chunk [391/1893] bpb=1.135899 time=98.5s - ttt_chunk [401/1893] bpb=1.133868 time=101.0s - ttt_chunk [411/1893] bpb=1.132670 time=103.5s - ttt_chunk [421/1893] bpb=1.131713 time=106.0s - ttt_chunk [431/1893] bpb=1.131589 time=108.6s - ttt_chunk [441/1893] bpb=1.131957 time=111.2s - ttt_chunk [451/1893] bpb=1.132313 time=113.7s - ttt_chunk [461/1893] bpb=1.131246 time=116.3s - ttt_chunk [471/1893] bpb=1.131909 time=118.9s - ttt_chunk [481/1893] bpb=1.131550 time=121.4s - ttt_chunk [491/1893] bpb=1.130408 time=123.9s - ttt_chunk [501/1893] bpb=1.129927 time=126.4s - ttt_chunk [511/1893] bpb=1.129274 time=129.0s - ttt_chunk [521/1893] bpb=1.126928 time=131.5s - ttt_chunk [531/1893] bpb=1.128112 time=134.2s - ttt_chunk [541/1893] bpb=1.128485 time=136.7s - ttt_chunk [551/1893] bpb=1.127460 time=139.2s - ttt_chunk [561/1893] bpb=1.127993 time=141.8s - ttt_chunk [571/1893] bpb=1.126972 time=144.3s - ttt_chunk [581/1893] bpb=1.126175 time=146.7s - ttt_chunk [591/1893] bpb=1.125528 time=149.2s - ttt_chunk [601/1893] bpb=1.126022 time=151.7s - ttt_chunk [611/1893] bpb=1.125971 time=154.2s - ttt_chunk [621/1893] bpb=1.125789 time=156.8s - ttt_chunk [631/1893] bpb=1.126474 time=159.3s - ttt_chunk [641/1893] bpb=1.126179 time=161.7s - ttt_chunk [651/1893] 
bpb=1.126297 time=164.2s - ttt_chunk [661/1893] bpb=1.125733 time=166.6s - ttt_chunk [671/1893] bpb=1.126078 time=169.1s - ttt_chunk [681/1893] bpb=1.126827 time=171.6s - ttt_chunk [691/1893] bpb=1.127830 time=174.2s - ttt_chunk [701/1893] bpb=1.127278 time=176.7s - ttt_chunk [711/1893] bpb=1.127234 time=179.3s - ttt_chunk [721/1893] bpb=1.126901 time=181.7s - ttt_chunk [731/1893] bpb=1.126944 time=184.3s - ttt_chunk [741/1893] bpb=1.127065 time=186.9s - ttt_chunk [751/1893] bpb=1.126893 time=189.4s - ttt_chunk [761/1893] bpb=1.126812 time=192.0s - ttt_chunk [771/1893] bpb=1.126504 time=194.6s - ttt_chunk [781/1893] bpb=1.127252 time=197.1s - ttt_chunk [791/1893] bpb=1.126853 time=199.7s - ttt_chunk [801/1893] bpb=1.127143 time=202.2s - ttt_chunk [811/1893] bpb=1.126883 time=204.8s - ttt_chunk [821/1893] bpb=1.126688 time=207.4s - ttt_chunk [831/1893] bpb=1.126484 time=209.9s - ttt_chunk [841/1893] bpb=1.125830 time=212.4s - ttt_chunk [851/1893] bpb=1.125559 time=214.9s - ttt_chunk [861/1893] bpb=1.125279 time=217.4s - ttt_chunk [871/1893] bpb=1.125553 time=220.0s - ttt_chunk [881/1893] bpb=1.125716 time=222.3s - ttt_chunk [891/1893] bpb=1.125269 time=224.7s - ttt_chunk [901/1893] bpb=1.124994 time=227.1s - ttt_chunk [911/1893] bpb=1.125112 time=229.6s - ttt_chunk [921/1893] bpb=1.125599 time=232.2s - ttt_chunk [931/1893] bpb=1.125550 time=234.7s - ttt_chunk [941/1893] bpb=1.125243 time=237.2s - ttt_chunk [951/1893] bpb=1.125621 time=239.7s - ttt_chunk [961/1893] bpb=1.125706 time=242.3s - ttt_chunk [971/1893] bpb=1.126561 time=244.9s - ttt_chunk [981/1893] bpb=1.126623 time=247.4s - ttt_chunk [991/1893] bpb=1.126653 time=250.0s - ttt_chunk [1001/1893] bpb=1.126605 time=252.6s - ttt_chunk [1011/1893] bpb=1.126404 time=255.2s - ttt_chunk [1021/1893] bpb=1.126738 time=257.7s - ttt_chunk [1031/1893] bpb=1.127177 time=260.3s - ttt_chunk [1041/1893] bpb=1.126844 time=262.9s - ttt_chunk [1051/1893] bpb=1.126600 time=265.4s - ttt_chunk [1061/1893] bpb=1.126634 time=268.0s 
- ttt_chunk [1071/1893] bpb=1.127241 time=270.5s - ttt_chunk [1081/1893] bpb=1.127504 time=273.0s - ttt_chunk [1091/1893] bpb=1.128216 time=275.6s - ttt_chunk [1101/1893] bpb=1.128229 time=278.2s - ttt_chunk [1111/1893] bpb=1.128077 time=280.7s - ttt_chunk [1121/1893] bpb=1.127884 time=283.2s - ttt_chunk [1131/1893] bpb=1.127768 time=285.6s - ttt_chunk [1141/1893] bpb=1.127473 time=288.1s - ttt_chunk [1151/1893] bpb=1.127471 time=290.6s - ttt_chunk [1161/1893] bpb=1.127077 time=293.2s - ttt_chunk [1171/1893] bpb=1.127404 time=295.7s - ttt_chunk [1181/1893] bpb=1.126651 time=298.2s - ttt_chunk [1191/1893] bpb=1.126516 time=300.7s - ttt_chunk [1201/1893] bpb=1.126927 time=303.3s - ttt_chunk [1211/1893] bpb=1.126439 time=305.7s - ttt_chunk [1221/1893] bpb=1.126135 time=308.3s - ttt_chunk [1231/1893] bpb=1.125853 time=310.8s - ttt_chunk [1241/1893] bpb=1.125502 time=313.3s - ttt_chunk [1251/1893] bpb=1.124889 time=315.6s - ttt_chunk [1261/1893] bpb=1.124865 time=318.1s - ttt_chunk [1271/1893] bpb=1.124492 time=320.6s - ttt_chunk [1281/1893] bpb=1.124279 time=323.1s - ttt_chunk [1291/1893] bpb=1.124047 time=325.6s - ttt_chunk [1301/1893] bpb=1.123447 time=328.2s - ttt_chunk [1311/1893] bpb=1.123041 time=330.6s - ttt_chunk [1321/1893] bpb=1.122713 time=333.1s - ttt_chunk [1331/1893] bpb=1.122643 time=335.7s - ttt_chunk [1341/1893] bpb=1.122511 time=338.2s - ttt_chunk [1351/1893] bpb=1.122444 time=340.7s - ttt_chunk [1361/1893] bpb=1.122503 time=343.3s - ttt_chunk [1371/1893] bpb=1.122372 time=345.8s - ttt_chunk [1381/1893] bpb=1.122364 time=348.3s - ttt_chunk [1391/1893] bpb=1.121966 time=350.7s - ttt_chunk [1401/1893] bpb=1.121917 time=353.2s - ttt_chunk [1411/1893] bpb=1.122037 time=355.8s - ttt_chunk [1421/1893] bpb=1.122287 time=358.2s - ttt_chunk [1431/1893] bpb=1.121974 time=360.7s - ttt_chunk [1441/1893] bpb=1.122479 time=363.4s - ttt_chunk [1451/1893] bpb=1.122817 time=365.8s - ttt_chunk [1461/1893] bpb=1.122370 time=368.3s - ttt_chunk [1471/1893] bpb=1.123396 
time=370.9s - ttt_chunk [1481/1893] bpb=1.122929 time=373.4s - ttt_chunk [1491/1893] bpb=1.122764 time=375.9s - ttt_chunk [1501/1893] bpb=1.122665 time=378.4s - ttt_chunk [1511/1893] bpb=1.122663 time=380.9s - ttt_chunk [1521/1893] bpb=1.122690 time=383.4s - ttt_chunk [1531/1893] bpb=1.122164 time=386.0s - ttt_chunk [1541/1893] bpb=1.122022 time=388.5s - ttt_chunk [1551/1893] bpb=1.122325 time=391.0s - ttt_chunk [1561/1893] bpb=1.122322 time=393.6s - ttt_chunk [1571/1893] bpb=1.122156 time=396.1s - ttt_chunk [1581/1893] bpb=1.122251 time=398.6s - ttt_chunk [1591/1893] bpb=1.122103 time=401.2s - ttt_chunk [1601/1893] bpb=1.122263 time=403.7s - ttt_chunk [1611/1893] bpb=1.122190 time=406.3s - ttt_chunk [1621/1893] bpb=1.121789 time=408.6s - ttt_chunk [1631/1893] bpb=1.122095 time=411.2s - ttt_chunk [1641/1893] bpb=1.122104 time=413.7s - ttt_chunk [1651/1893] bpb=1.122062 time=416.2s - ttt_chunk [1661/1893] bpb=1.121936 time=418.8s - ttt_chunk [1671/1893] bpb=1.122392 time=421.3s - ttt_chunk [1681/1893] bpb=1.122541 time=423.9s - ttt_chunk [1691/1893] bpb=1.122379 time=426.4s - ttt_chunk [1701/1893] bpb=1.122539 time=429.0s - ttt_chunk [1711/1893] bpb=1.122538 time=431.6s - ttt_chunk [1721/1893] bpb=1.122549 time=434.0s - ttt_chunk [1731/1893] bpb=1.122438 time=436.6s - ttt_chunk [1741/1893] bpb=1.122231 time=439.1s - ttt_chunk [1751/1893] bpb=1.122050 time=441.7s - ttt_chunk [1761/1893] bpb=1.122201 time=444.2s - ttt_chunk [1771/1893] bpb=1.122112 time=446.7s - ttt_chunk [1781/1893] bpb=1.122127 time=449.3s - ttt_chunk [1791/1893] bpb=1.121717 time=451.7s - ttt_chunk [1801/1893] bpb=1.121588 time=454.2s - ttt_chunk [1811/1893] bpb=1.121473 time=456.8s - ttt_chunk [1821/1893] bpb=1.121540 time=459.4s - ttt_chunk [1831/1893] bpb=1.120933 time=461.9s - ttt_chunk [1841/1893] bpb=1.120892 time=464.4s - ttt_chunk [1851/1893] bpb=1.120680 time=466.9s - ttt_chunk [1861/1893] bpb=1.120309 time=469.4s - ttt_chunk [1871/1893] bpb=1.120302 time=472.0s - ttt_chunk [1881/1893] 
bpb=1.119857 time=474.4s - ttt_chunk [1891/1893] bpb=1.119623 time=477.0s - ttt_chunk [1893/1893] bpb=1.119667 time=477.3s -ttt_sliding:done val_loss=1.887101 val_bpb=1.117650 elapsed=477.3s -legal_ttt val_loss:1.8871 val_bpb:1.1177 eval_time:477854ms -legal_ttt_exact val_loss:1.88710072 val_bpb:1.11765030 diff --git a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed2025.log b/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed2025.log deleted file mode 100644 index 26b3b1f62e..0000000000 --- a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed2025.log +++ /dev/null @@ -1,275 +0,0 @@ -W0328 01:35:38.316000 93345 torch/distributed/run.py:803] -W0328 01:35:38.316000 93345 torch/distributed/run.py:803] ***************************************** -W0328 01:35:38.316000 93345 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
-W0328 01:35:38.316000 93345 torch/distributed/run.py:803] ***************************************** -logs/51ad2124-736f-4927-8d11-83429489d7ce.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26928220 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:599.000 -seed:2025 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.9302 val_bpb:4.1045 train_time:0ms step_avg:0.02ms -step:1/9000 train_loss:6.9311 train_time:132ms step_avg:131.62ms -step:2/9000 train_loss:8.6818 train_time:160ms step_avg:79.78ms -step:3/9000 train_loss:7.7058 train_time:240ms step_avg:80.14ms -step:4/9000 train_loss:7.2716 train_time:321ms step_avg:80.29ms -step:5/9000 train_loss:7.1779 train_time:404ms step_avg:80.78ms -step:6/9000 train_loss:7.0950 train_time:491ms step_avg:81.91ms -step:7/9000 train_loss:7.0228 train_time:572ms step_avg:81.69ms -step:8/9000 train_loss:6.9406 train_time:653ms step_avg:81.64ms -step:9/9000 train_loss:6.6071 train_time:734ms step_avg:81.52ms -step:10/9000 train_loss:6.2037 train_time:815ms step_avg:81.49ms -step:500/9000 train_loss:2.3991 
train_time:41443ms step_avg:82.89ms -step:1000/9000 train_loss:2.2642 train_time:83160ms step_avg:83.16ms -step:1500/9000 train_loss:2.2121 train_time:124925ms step_avg:83.28ms -step:2000/9000 train_loss:2.0516 train_time:166699ms step_avg:83.35ms -step:2500/9000 train_loss:2.1578 train_time:208449ms step_avg:83.38ms -step:3000/9000 train_loss:2.1498 train_time:250264ms step_avg:83.42ms -step:3500/9000 train_loss:2.1680 train_time:291976ms step_avg:83.42ms -step:4000/9000 train_loss:1.9631 train_time:333663ms step_avg:83.42ms -step:4000/9000 val_loss:2.0570 val_bpb:1.2183 train_time:333712ms step_avg:83.43ms -step:4500/9000 train_loss:2.1176 train_time:375327ms step_avg:83.41ms -step:5000/9000 train_loss:2.0992 train_time:416989ms step_avg:83.40ms -step:5500/9000 train_loss:2.0111 train_time:458636ms step_avg:83.39ms -step:6000/9000 train_loss:1.9384 train_time:500279ms step_avg:83.38ms -swa:start step:6500 -step:6500/9000 train_loss:2.0815 train_time:541933ms step_avg:83.37ms -late_qat:enabled step:6656 scale:0.1498 -step:7000/9000 train_loss:1.7874 train_time:584240ms step_avg:83.46ms -step:7175/9000 val_loss:1.9210 val_bpb:1.1377 train_time:599068ms step_avg:83.49ms -stopping_early: wallclock_cap train_time:599068ms step:7175/9000 -peak memory allocated: 21471 MiB reserved: 22002 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9193 val_bpb:1.1367 eval_time:1977ms -Serialized model: 106027446 bytes -Code size: 93038 bytes -Serialized model int6+lzma: 15786004 bytes -Total submission size int6+lzma: 15879042 bytes -final_int6_roundtrip val_loss:1.9333 val_bpb:1.1450 eval_time:6327ms -final_int6_roundtrip_exact val_loss:1.93333252 val_bpb:1.14502842 -final_int6_sliding_window val_loss:1.8939 val_bpb:1.1217 stride:64 eval_time:74606ms -final_int6_sliding_window_exact val_loss:1.89387417 val_bpb:1.12166193 -final_int8_zlib_roundtrip_exact val_loss:1.89387417 val_bpb:1.12166193 -ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 
stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 -ttt_sliding:params unfrozen=26928220 frozen=0 - ttt_chunk [1/1893] bpb=1.157213 time=0.5s - ttt_chunk [11/1893] bpb=1.146150 time=3.0s - ttt_chunk [21/1893] bpb=1.131983 time=5.5s - ttt_chunk [31/1893] bpb=1.129800 time=8.0s - ttt_chunk [41/1893] bpb=1.115753 time=10.6s - ttt_chunk [51/1893] bpb=1.109750 time=13.0s - ttt_chunk [61/1893] bpb=1.115753 time=15.6s - ttt_chunk [71/1893] bpb=1.114495 time=18.1s - ttt_chunk [81/1893] bpb=1.113640 time=20.6s - ttt_chunk [91/1893] bpb=1.114745 time=23.0s - ttt_chunk [101/1893] bpb=1.118217 time=25.6s - ttt_chunk [111/1893] bpb=1.120613 time=28.1s - ttt_chunk [121/1893] bpb=1.113957 time=30.5s - ttt_chunk [131/1893] bpb=1.114078 time=33.0s - ttt_chunk [141/1893] bpb=1.119665 time=35.6s - ttt_chunk [151/1893] bpb=1.121456 time=38.1s - ttt_chunk [161/1893] bpb=1.120832 time=40.6s - ttt_chunk [171/1893] bpb=1.125423 time=43.2s - ttt_chunk [181/1893] bpb=1.127674 time=45.7s - ttt_chunk [191/1893] bpb=1.135139 time=48.2s - ttt_chunk [201/1893] bpb=1.133873 time=50.8s - ttt_chunk [211/1893] bpb=1.131733 time=53.3s - ttt_chunk [221/1893] bpb=1.133254 time=55.8s - ttt_chunk [231/1893] bpb=1.131978 time=58.3s - ttt_chunk [241/1893] bpb=1.132435 time=60.9s - ttt_chunk [251/1893] bpb=1.131925 time=63.4s - ttt_chunk [261/1893] bpb=1.129055 time=65.8s - ttt_chunk [271/1893] bpb=1.127913 time=68.3s - ttt_chunk [281/1893] bpb=1.129263 time=70.8s - ttt_chunk [291/1893] bpb=1.131080 time=73.3s - ttt_chunk [301/1893] bpb=1.131819 time=75.9s - ttt_chunk [311/1893] bpb=1.133852 time=78.4s - ttt_chunk [321/1893] bpb=1.135848 time=80.9s - ttt_chunk [331/1893] bpb=1.135752 time=83.4s - ttt_chunk [341/1893] bpb=1.134782 time=86.0s - ttt_chunk [351/1893] bpb=1.137073 time=88.5s - ttt_chunk [361/1893] bpb=1.137237 time=91.0s - ttt_chunk [371/1893] bpb=1.136629 time=93.5s - ttt_chunk [381/1893] bpb=1.136776 time=96.1s - ttt_chunk [391/1893] bpb=1.136570 time=98.7s - ttt_chunk [401/1893] bpb=1.134505 
time=101.2s - ttt_chunk [411/1893] bpb=1.133340 time=103.7s - ttt_chunk [421/1893] bpb=1.132381 time=106.2s - ttt_chunk [431/1893] bpb=1.132265 time=108.8s - ttt_chunk [441/1893] bpb=1.132715 time=111.4s - ttt_chunk [451/1893] bpb=1.133043 time=113.9s - ttt_chunk [461/1893] bpb=1.131903 time=116.5s - ttt_chunk [471/1893] bpb=1.132493 time=119.1s - ttt_chunk [481/1893] bpb=1.132134 time=121.6s - ttt_chunk [491/1893] bpb=1.131040 time=124.2s - ttt_chunk [501/1893] bpb=1.130580 time=126.8s - ttt_chunk [511/1893] bpb=1.129868 time=129.3s - ttt_chunk [521/1893] bpb=1.127426 time=131.9s - ttt_chunk [531/1893] bpb=1.128584 time=134.6s - ttt_chunk [541/1893] bpb=1.128870 time=137.1s - ttt_chunk [551/1893] bpb=1.127882 time=139.7s - ttt_chunk [561/1893] bpb=1.128434 time=142.3s - ttt_chunk [571/1893] bpb=1.127393 time=144.8s - ttt_chunk [581/1893] bpb=1.126593 time=147.2s - ttt_chunk [591/1893] bpb=1.125926 time=149.7s - ttt_chunk [601/1893] bpb=1.126401 time=152.2s - ttt_chunk [611/1893] bpb=1.126301 time=154.7s - ttt_chunk [621/1893] bpb=1.126164 time=157.3s - ttt_chunk [631/1893] bpb=1.126878 time=159.8s - ttt_chunk [641/1893] bpb=1.126621 time=162.3s - ttt_chunk [651/1893] bpb=1.126749 time=164.8s - ttt_chunk [661/1893] bpb=1.126248 time=167.2s - ttt_chunk [671/1893] bpb=1.126593 time=169.7s - ttt_chunk [681/1893] bpb=1.127317 time=172.2s - ttt_chunk [691/1893] bpb=1.128289 time=174.8s - ttt_chunk [701/1893] bpb=1.127720 time=177.3s - ttt_chunk [711/1893] bpb=1.127718 time=179.8s - ttt_chunk [721/1893] bpb=1.127401 time=182.3s - ttt_chunk [731/1893] bpb=1.127458 time=184.8s - ttt_chunk [741/1893] bpb=1.127582 time=187.3s - ttt_chunk [751/1893] bpb=1.127447 time=189.8s - ttt_chunk [761/1893] bpb=1.127362 time=192.4s - ttt_chunk [771/1893] bpb=1.127045 time=195.0s - ttt_chunk [781/1893] bpb=1.127768 time=197.5s - ttt_chunk [791/1893] bpb=1.127357 time=200.0s - ttt_chunk [801/1893] bpb=1.127659 time=202.6s - ttt_chunk [811/1893] bpb=1.127398 time=205.1s - ttt_chunk 
[821/1893] bpb=1.127153 time=207.7s - ttt_chunk [831/1893] bpb=1.126984 time=210.2s - ttt_chunk [841/1893] bpb=1.126338 time=212.7s - ttt_chunk [851/1893] bpb=1.126043 time=215.2s - ttt_chunk [861/1893] bpb=1.125803 time=217.7s - ttt_chunk [871/1893] bpb=1.126077 time=220.3s - ttt_chunk [881/1893] bpb=1.126249 time=222.6s - ttt_chunk [891/1893] bpb=1.125808 time=225.0s - ttt_chunk [901/1893] bpb=1.125530 time=227.5s - ttt_chunk [911/1893] bpb=1.125640 time=230.0s - ttt_chunk [921/1893] bpb=1.126086 time=232.5s - ttt_chunk [931/1893] bpb=1.126042 time=235.1s - ttt_chunk [941/1893] bpb=1.125707 time=237.6s - ttt_chunk [951/1893] bpb=1.126083 time=240.1s - ttt_chunk [961/1893] bpb=1.126157 time=242.7s - ttt_chunk [971/1893] bpb=1.127000 time=245.3s - ttt_chunk [981/1893] bpb=1.127070 time=247.8s - ttt_chunk [991/1893] bpb=1.127091 time=250.4s - ttt_chunk [1001/1893] bpb=1.127044 time=252.9s - ttt_chunk [1011/1893] bpb=1.126815 time=255.5s - ttt_chunk [1021/1893] bpb=1.127150 time=258.1s - ttt_chunk [1031/1893] bpb=1.127600 time=260.7s - ttt_chunk [1041/1893] bpb=1.127254 time=263.3s - ttt_chunk [1051/1893] bpb=1.127004 time=265.8s - ttt_chunk [1061/1893] bpb=1.127072 time=268.4s - ttt_chunk [1071/1893] bpb=1.127706 time=270.9s - ttt_chunk [1081/1893] bpb=1.127954 time=273.5s - ttt_chunk [1091/1893] bpb=1.128696 time=276.1s - ttt_chunk [1101/1893] bpb=1.128701 time=278.6s - ttt_chunk [1111/1893] bpb=1.128574 time=281.1s - ttt_chunk [1121/1893] bpb=1.128361 time=283.7s - ttt_chunk [1131/1893] bpb=1.128230 time=286.1s - ttt_chunk [1141/1893] bpb=1.127928 time=288.6s - ttt_chunk [1151/1893] bpb=1.127948 time=291.2s - ttt_chunk [1161/1893] bpb=1.127576 time=293.7s - ttt_chunk [1171/1893] bpb=1.127908 time=296.2s - ttt_chunk [1181/1893] bpb=1.127175 time=298.7s - ttt_chunk [1191/1893] bpb=1.127049 time=301.3s - ttt_chunk [1201/1893] bpb=1.127446 time=303.9s - ttt_chunk [1211/1893] bpb=1.126968 time=306.3s - ttt_chunk [1221/1893] bpb=1.126658 time=308.8s - ttt_chunk 
[1231/1893] bpb=1.126389 time=311.4s - ttt_chunk [1241/1893] bpb=1.126011 time=313.9s - ttt_chunk [1251/1893] bpb=1.125410 time=316.3s - ttt_chunk [1261/1893] bpb=1.125383 time=318.8s - ttt_chunk [1271/1893] bpb=1.124974 time=321.3s - ttt_chunk [1281/1893] bpb=1.124793 time=323.9s - ttt_chunk [1291/1893] bpb=1.124558 time=326.4s - ttt_chunk [1301/1893] bpb=1.123981 time=329.0s - ttt_chunk [1311/1893] bpb=1.123591 time=331.5s - ttt_chunk [1321/1893] bpb=1.123236 time=334.0s - ttt_chunk [1331/1893] bpb=1.123155 time=336.5s - ttt_chunk [1341/1893] bpb=1.123018 time=339.1s - ttt_chunk [1351/1893] bpb=1.122951 time=341.6s - ttt_chunk [1361/1893] bpb=1.122989 time=344.1s - ttt_chunk [1371/1893] bpb=1.122853 time=346.7s - ttt_chunk [1381/1893] bpb=1.122825 time=349.2s - ttt_chunk [1391/1893] bpb=1.122414 time=351.7s - ttt_chunk [1401/1893] bpb=1.122356 time=354.2s - ttt_chunk [1411/1893] bpb=1.122453 time=356.7s - ttt_chunk [1421/1893] bpb=1.122688 time=359.2s - ttt_chunk [1431/1893] bpb=1.122386 time=361.7s - ttt_chunk [1441/1893] bpb=1.122874 time=364.4s - ttt_chunk [1451/1893] bpb=1.123198 time=366.9s - ttt_chunk [1461/1893] bpb=1.122727 time=369.4s - ttt_chunk [1471/1893] bpb=1.123774 time=372.0s - ttt_chunk [1481/1893] bpb=1.123313 time=374.6s - ttt_chunk [1491/1893] bpb=1.123123 time=377.1s - ttt_chunk [1501/1893] bpb=1.123038 time=379.6s - ttt_chunk [1511/1893] bpb=1.123030 time=382.1s - ttt_chunk [1521/1893] bpb=1.123039 time=384.7s - ttt_chunk [1531/1893] bpb=1.122531 time=387.2s - ttt_chunk [1541/1893] bpb=1.122379 time=389.7s - ttt_chunk [1551/1893] bpb=1.122681 time=392.3s - ttt_chunk [1561/1893] bpb=1.122675 time=394.8s - ttt_chunk [1571/1893] bpb=1.122502 time=397.4s - ttt_chunk [1581/1893] bpb=1.122613 time=399.9s - ttt_chunk [1591/1893] bpb=1.122467 time=402.5s - ttt_chunk [1601/1893] bpb=1.122628 time=405.0s - ttt_chunk [1611/1893] bpb=1.122552 time=407.6s - ttt_chunk [1621/1893] bpb=1.122132 time=410.0s - ttt_chunk [1631/1893] bpb=1.122421 time=412.5s - 
ttt_chunk [1641/1893] bpb=1.122421 time=415.0s - ttt_chunk [1651/1893] bpb=1.122362 time=417.6s - ttt_chunk [1661/1893] bpb=1.122235 time=420.1s - ttt_chunk [1671/1893] bpb=1.122698 time=422.7s - ttt_chunk [1681/1893] bpb=1.122854 time=425.2s - ttt_chunk [1691/1893] bpb=1.122683 time=427.8s - ttt_chunk [1701/1893] bpb=1.122819 time=430.4s - ttt_chunk [1711/1893] bpb=1.122802 time=432.9s - ttt_chunk [1721/1893] bpb=1.122804 time=435.3s - ttt_chunk [1731/1893] bpb=1.122671 time=437.9s - ttt_chunk [1741/1893] bpb=1.122453 time=440.5s - ttt_chunk [1751/1893] bpb=1.122274 time=443.0s - ttt_chunk [1761/1893] bpb=1.122422 time=445.6s - ttt_chunk [1771/1893] bpb=1.122321 time=448.1s - ttt_chunk [1781/1893] bpb=1.122341 time=450.7s - ttt_chunk [1791/1893] bpb=1.121937 time=453.1s - ttt_chunk [1801/1893] bpb=1.121812 time=455.7s - ttt_chunk [1811/1893] bpb=1.121711 time=458.2s - ttt_chunk [1821/1893] bpb=1.121754 time=460.8s - ttt_chunk [1831/1893] bpb=1.121146 time=463.4s - ttt_chunk [1841/1893] bpb=1.121066 time=465.9s - ttt_chunk [1851/1893] bpb=1.120842 time=468.3s - ttt_chunk [1861/1893] bpb=1.120475 time=470.9s - ttt_chunk [1871/1893] bpb=1.120468 time=473.4s - ttt_chunk [1881/1893] bpb=1.120013 time=475.8s - ttt_chunk [1891/1893] bpb=1.119775 time=478.4s - ttt_chunk [1893/1893] bpb=1.119819 time=478.7s -ttt_sliding:done val_loss=1.887521 val_bpb=1.117899 elapsed=478.7s -legal_ttt val_loss:1.8875 val_bpb:1.1179 eval_time:479224ms -legal_ttt_exact val_loss:1.88752121 val_bpb:1.11789934 diff --git a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed42.log b/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed42.log deleted file mode 100644 index b561a8866f..0000000000 --- a/records/track_10min_16mb/2026-03-28_MuonTTT_EntropyAdaptive_11L_8xH100/train_seed42.log +++ /dev/null @@ -1,275 +0,0 @@ -W0328 01:58:00.242000 94313 torch/distributed/run.py:803] -W0328 01:58:00.242000 94313 torch/distributed/run.py:803] 
***************************************** -W0328 01:58:00.242000 94313 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0328 01:58:00.242000 94313 torch/distributed/run.py:803] ***************************************** -logs/58a783cd-c3a9-42ba-8a34-26ee87798f69.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26928220 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:599.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.9293 val_bpb:4.1039 train_time:0ms step_avg:0.01ms -step:1/9000 train_loss:6.9308 train_time:131ms step_avg:130.99ms -step:2/9000 train_loss:8.6422 train_time:160ms step_avg:79.92ms -step:3/9000 train_loss:7.6901 train_time:240ms step_avg:80.04ms -step:4/9000 train_loss:7.2782 train_time:322ms step_avg:80.54ms -step:5/9000 train_loss:7.2222 train_time:403ms step_avg:80.58ms -step:6/9000 train_loss:7.1409 train_time:485ms 
step_avg:80.86ms -step:7/9000 train_loss:7.0923 train_time:568ms step_avg:81.07ms -step:8/9000 train_loss:7.0291 train_time:650ms step_avg:81.28ms -step:9/9000 train_loss:6.6343 train_time:731ms step_avg:81.26ms -step:10/9000 train_loss:6.2569 train_time:815ms step_avg:81.48ms -step:500/9000 train_loss:2.3952 train_time:41422ms step_avg:82.84ms -step:1000/9000 train_loss:2.2574 train_time:83206ms step_avg:83.21ms -step:1500/9000 train_loss:2.2105 train_time:124956ms step_avg:83.30ms -step:2000/9000 train_loss:2.0560 train_time:166685ms step_avg:83.34ms -step:2500/9000 train_loss:2.1603 train_time:208428ms step_avg:83.37ms -step:3000/9000 train_loss:2.1494 train_time:250139ms step_avg:83.38ms -step:3500/9000 train_loss:2.1733 train_time:291834ms step_avg:83.38ms -step:4000/9000 train_loss:1.9682 train_time:333534ms step_avg:83.38ms -step:4000/9000 val_loss:2.0572 val_bpb:1.2184 train_time:333585ms step_avg:83.40ms -step:4500/9000 train_loss:2.1198 train_time:375205ms step_avg:83.38ms -step:5000/9000 train_loss:2.1010 train_time:416849ms step_avg:83.37ms -step:5500/9000 train_loss:2.0140 train_time:458486ms step_avg:83.36ms -step:6000/9000 train_loss:1.9381 train_time:500134ms step_avg:83.36ms -swa:start step:6500 -step:6500/9000 train_loss:2.0817 train_time:541764ms step_avg:83.35ms -late_qat:enabled step:6658 scale:0.1500 -step:7000/9000 train_loss:1.7891 train_time:584037ms step_avg:83.43ms -step:7177/9000 val_loss:1.9217 val_bpb:1.1381 train_time:599058ms step_avg:83.47ms -stopping_early: wallclock_cap train_time:599058ms step:7177/9000 -peak memory allocated: 21471 MiB reserved: 22002 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9199 val_bpb:1.1371 eval_time:1977ms -Serialized model: 106027446 bytes -Code size: 93038 bytes -Serialized model int6+lzma: 15780788 bytes -Total submission size int6+lzma: 15873826 bytes -final_int6_roundtrip val_loss:1.9337 val_bpb:1.1452 eval_time:6271ms -final_int6_roundtrip_exact val_loss:1.93365776 
val_bpb:1.14522104 -final_int6_sliding_window val_loss:1.8939 val_bpb:1.1217 stride:64 eval_time:74583ms -final_int6_sliding_window_exact val_loss:1.89387243 val_bpb:1.12166090 -final_int8_zlib_roundtrip_exact val_loss:1.89387243 val_bpb:1.12166090 -ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 -ttt_sliding:params unfrozen=26928220 frozen=0 - ttt_chunk [1/1893] bpb=1.163564 time=0.5s - ttt_chunk [11/1893] bpb=1.147559 time=3.0s - ttt_chunk [21/1893] bpb=1.131890 time=5.5s - ttt_chunk [31/1893] bpb=1.129424 time=8.2s - ttt_chunk [41/1893] bpb=1.115715 time=10.7s - ttt_chunk [51/1893] bpb=1.110300 time=13.2s - ttt_chunk [61/1893] bpb=1.116705 time=15.8s - ttt_chunk [71/1893] bpb=1.115147 time=18.4s - ttt_chunk [81/1893] bpb=1.114224 time=20.8s - ttt_chunk [91/1893] bpb=1.115324 time=23.3s - ttt_chunk [101/1893] bpb=1.118834 time=25.9s - ttt_chunk [111/1893] bpb=1.121161 time=28.4s - ttt_chunk [121/1893] bpb=1.114292 time=30.9s - ttt_chunk [131/1893] bpb=1.114635 time=33.4s - ttt_chunk [141/1893] bpb=1.120073 time=36.0s - ttt_chunk [151/1893] bpb=1.121918 time=38.6s - ttt_chunk [161/1893] bpb=1.121185 time=41.2s - ttt_chunk [171/1893] bpb=1.125485 time=43.8s - ttt_chunk [181/1893] bpb=1.127617 time=46.3s - ttt_chunk [191/1893] bpb=1.134961 time=48.9s - ttt_chunk [201/1893] bpb=1.133663 time=51.5s - ttt_chunk [211/1893] bpb=1.131524 time=54.1s - ttt_chunk [221/1893] bpb=1.132942 time=56.6s - ttt_chunk [231/1893] bpb=1.131655 time=59.1s - ttt_chunk [241/1893] bpb=1.132049 time=61.7s - ttt_chunk [251/1893] bpb=1.131664 time=64.3s - ttt_chunk [261/1893] bpb=1.128766 time=66.8s - ttt_chunk [271/1893] bpb=1.127561 time=69.2s - ttt_chunk [281/1893] bpb=1.128924 time=71.8s - ttt_chunk [291/1893] bpb=1.130687 time=74.4s - ttt_chunk [301/1893] bpb=1.131425 time=76.9s - ttt_chunk [311/1893] bpb=1.133488 time=79.5s - ttt_chunk [321/1893] bpb=1.135443 time=82.0s - ttt_chunk [331/1893] bpb=1.135227 
time=84.6s - ttt_chunk [341/1893] bpb=1.134219 time=87.2s - ttt_chunk [351/1893] bpb=1.136494 time=89.7s - ttt_chunk [361/1893] bpb=1.136633 time=92.2s - ttt_chunk [371/1893] bpb=1.135976 time=94.8s - ttt_chunk [381/1893] bpb=1.136160 time=97.4s - ttt_chunk [391/1893] bpb=1.135987 time=100.0s - ttt_chunk [401/1893] bpb=1.133966 time=102.6s - ttt_chunk [411/1893] bpb=1.132784 time=105.1s - ttt_chunk [421/1893] bpb=1.131921 time=107.6s - ttt_chunk [431/1893] bpb=1.131814 time=110.3s - ttt_chunk [441/1893] bpb=1.132189 time=112.9s - ttt_chunk [451/1893] bpb=1.132478 time=115.5s - ttt_chunk [461/1893] bpb=1.131385 time=118.1s - ttt_chunk [471/1893] bpb=1.132036 time=120.7s - ttt_chunk [481/1893] bpb=1.131638 time=123.3s - ttt_chunk [491/1893] bpb=1.130494 time=125.8s - ttt_chunk [501/1893] bpb=1.130040 time=128.4s - ttt_chunk [511/1893] bpb=1.129401 time=131.0s - ttt_chunk [521/1893] bpb=1.126977 time=133.6s - ttt_chunk [531/1893] bpb=1.128169 time=136.2s - ttt_chunk [541/1893] bpb=1.128509 time=138.8s - ttt_chunk [551/1893] bpb=1.127494 time=141.4s - ttt_chunk [561/1893] bpb=1.128034 time=144.0s - ttt_chunk [571/1893] bpb=1.127037 time=146.5s - ttt_chunk [581/1893] bpb=1.126260 time=149.0s - ttt_chunk [591/1893] bpb=1.125674 time=151.5s - ttt_chunk [601/1893] bpb=1.126156 time=154.1s - ttt_chunk [611/1893] bpb=1.126091 time=156.6s - ttt_chunk [621/1893] bpb=1.125943 time=159.2s - ttt_chunk [631/1893] bpb=1.126689 time=161.7s - ttt_chunk [641/1893] bpb=1.126468 time=164.2s - ttt_chunk [651/1893] bpb=1.126543 time=166.8s - ttt_chunk [661/1893] bpb=1.126027 time=169.3s - ttt_chunk [671/1893] bpb=1.126396 time=171.8s - ttt_chunk [681/1893] bpb=1.127132 time=174.4s - ttt_chunk [691/1893] bpb=1.128127 time=177.0s - ttt_chunk [701/1893] bpb=1.127586 time=179.5s - ttt_chunk [711/1893] bpb=1.127590 time=182.0s - ttt_chunk [721/1893] bpb=1.127208 time=184.5s - ttt_chunk [731/1893] bpb=1.127258 time=187.1s - ttt_chunk [741/1893] bpb=1.127373 time=189.6s - ttt_chunk [751/1893] 
bpb=1.127246 time=192.1s - ttt_chunk [761/1893] bpb=1.127144 time=194.7s - ttt_chunk [771/1893] bpb=1.126834 time=197.3s - ttt_chunk [781/1893] bpb=1.127591 time=199.9s - ttt_chunk [791/1893] bpb=1.127190 time=202.4s - ttt_chunk [801/1893] bpb=1.127493 time=205.0s - ttt_chunk [811/1893] bpb=1.127250 time=207.6s - ttt_chunk [821/1893] bpb=1.127009 time=210.2s - ttt_chunk [831/1893] bpb=1.126855 time=212.7s - ttt_chunk [841/1893] bpb=1.126217 time=215.3s - ttt_chunk [851/1893] bpb=1.125976 time=217.8s - ttt_chunk [861/1893] bpb=1.125732 time=220.3s - ttt_chunk [871/1893] bpb=1.126004 time=222.9s - ttt_chunk [881/1893] bpb=1.126166 time=225.3s - ttt_chunk [891/1893] bpb=1.125728 time=227.8s - ttt_chunk [901/1893] bpb=1.125443 time=230.2s - ttt_chunk [911/1893] bpb=1.125556 time=232.8s - ttt_chunk [921/1893] bpb=1.126022 time=235.3s - ttt_chunk [931/1893] bpb=1.125947 time=237.9s - ttt_chunk [941/1893] bpb=1.125640 time=240.5s - ttt_chunk [951/1893] bpb=1.126033 time=243.1s - ttt_chunk [961/1893] bpb=1.126097 time=245.7s - ttt_chunk [971/1893] bpb=1.126940 time=248.3s - ttt_chunk [981/1893] bpb=1.127032 time=250.8s - ttt_chunk [991/1893] bpb=1.127043 time=253.4s - ttt_chunk [1001/1893] bpb=1.126998 time=256.0s - ttt_chunk [1011/1893] bpb=1.126757 time=258.6s - ttt_chunk [1021/1893] bpb=1.127099 time=261.2s - ttt_chunk [1031/1893] bpb=1.127547 time=263.8s - ttt_chunk [1041/1893] bpb=1.127210 time=266.5s - ttt_chunk [1051/1893] bpb=1.126970 time=269.1s - ttt_chunk [1061/1893] bpb=1.127002 time=271.7s - ttt_chunk [1071/1893] bpb=1.127603 time=274.2s - ttt_chunk [1081/1893] bpb=1.127870 time=276.9s - ttt_chunk [1091/1893] bpb=1.128612 time=279.5s - ttt_chunk [1101/1893] bpb=1.128638 time=282.0s - ttt_chunk [1111/1893] bpb=1.128492 time=284.6s - ttt_chunk [1121/1893] bpb=1.128274 time=287.2s - ttt_chunk [1131/1893] bpb=1.128135 time=289.7s - ttt_chunk [1141/1893] bpb=1.127824 time=292.2s - ttt_chunk [1151/1893] bpb=1.127844 time=294.7s - ttt_chunk [1161/1893] bpb=1.127462 
time=297.3s - ttt_chunk [1171/1893] bpb=1.127792 time=299.9s - ttt_chunk [1181/1893] bpb=1.127049 time=302.4s - ttt_chunk [1191/1893] bpb=1.126913 time=305.0s - ttt_chunk [1201/1893] bpb=1.127341 time=307.7s - ttt_chunk [1211/1893] bpb=1.126863 time=310.1s - ttt_chunk [1221/1893] bpb=1.126566 time=312.7s - ttt_chunk [1231/1893] bpb=1.126281 time=315.3s - ttt_chunk [1241/1893] bpb=1.125930 time=317.8s - ttt_chunk [1251/1893] bpb=1.125346 time=320.2s - ttt_chunk [1261/1893] bpb=1.125316 time=322.7s - ttt_chunk [1271/1893] bpb=1.124934 time=325.3s - ttt_chunk [1281/1893] bpb=1.124727 time=327.9s - ttt_chunk [1291/1893] bpb=1.124482 time=330.4s - ttt_chunk [1301/1893] bpb=1.123879 time=333.0s - ttt_chunk [1311/1893] bpb=1.123487 time=335.6s - ttt_chunk [1321/1893] bpb=1.123164 time=338.1s - ttt_chunk [1331/1893] bpb=1.123096 time=340.7s - ttt_chunk [1341/1893] bpb=1.122974 time=343.3s - ttt_chunk [1351/1893] bpb=1.122892 time=345.8s - ttt_chunk [1361/1893] bpb=1.122945 time=348.4s - ttt_chunk [1371/1893] bpb=1.122809 time=351.0s - ttt_chunk [1381/1893] bpb=1.122795 time=353.5s - ttt_chunk [1391/1893] bpb=1.122398 time=356.0s - ttt_chunk [1401/1893] bpb=1.122363 time=358.6s - ttt_chunk [1411/1893] bpb=1.122473 time=361.2s - ttt_chunk [1421/1893] bpb=1.122727 time=363.7s - ttt_chunk [1431/1893] bpb=1.122430 time=366.2s - ttt_chunk [1441/1893] bpb=1.122915 time=368.9s - ttt_chunk [1451/1893] bpb=1.123241 time=371.4s - ttt_chunk [1461/1893] bpb=1.122774 time=373.9s - ttt_chunk [1471/1893] bpb=1.123808 time=376.6s - ttt_chunk [1481/1893] bpb=1.123323 time=379.1s - ttt_chunk [1491/1893] bpb=1.123138 time=381.7s - ttt_chunk [1501/1893] bpb=1.123050 time=384.2s - ttt_chunk [1511/1893] bpb=1.123062 time=386.7s - ttt_chunk [1521/1893] bpb=1.123073 time=389.3s - ttt_chunk [1531/1893] bpb=1.122546 time=391.9s - ttt_chunk [1541/1893] bpb=1.122399 time=394.4s - ttt_chunk [1551/1893] bpb=1.122711 time=397.0s - ttt_chunk [1561/1893] bpb=1.122693 time=399.6s - ttt_chunk [1571/1893] 
bpb=1.122531 time=402.1s - ttt_chunk [1581/1893] bpb=1.122647 time=404.7s - ttt_chunk [1591/1893] bpb=1.122489 time=407.3s - ttt_chunk [1601/1893] bpb=1.122653 time=409.9s - ttt_chunk [1611/1893] bpb=1.122597 time=412.5s - ttt_chunk [1621/1893] bpb=1.122183 time=415.0s - ttt_chunk [1631/1893] bpb=1.122483 time=417.6s - ttt_chunk [1641/1893] bpb=1.122491 time=420.1s - ttt_chunk [1651/1893] bpb=1.122440 time=422.7s - ttt_chunk [1661/1893] bpb=1.122317 time=425.2s - ttt_chunk [1671/1893] bpb=1.122789 time=427.8s - ttt_chunk [1681/1893] bpb=1.122929 time=430.4s - ttt_chunk [1691/1893] bpb=1.122772 time=433.0s - ttt_chunk [1701/1893] bpb=1.122922 time=435.6s - ttt_chunk [1711/1893] bpb=1.122925 time=438.2s - ttt_chunk [1721/1893] bpb=1.122935 time=440.7s - ttt_chunk [1731/1893] bpb=1.122811 time=443.3s - ttt_chunk [1741/1893] bpb=1.122612 time=445.9s - ttt_chunk [1751/1893] bpb=1.122456 time=448.5s - ttt_chunk [1761/1893] bpb=1.122599 time=451.1s - ttt_chunk [1771/1893] bpb=1.122505 time=453.6s - ttt_chunk [1781/1893] bpb=1.122531 time=456.2s - ttt_chunk [1791/1893] bpb=1.122132 time=458.7s - ttt_chunk [1801/1893] bpb=1.122003 time=461.3s - ttt_chunk [1811/1893] bpb=1.121897 time=463.9s - ttt_chunk [1821/1893] bpb=1.121947 time=466.5s - ttt_chunk [1831/1893] bpb=1.121349 time=469.1s - ttt_chunk [1841/1893] bpb=1.121263 time=471.7s - ttt_chunk [1851/1893] bpb=1.121051 time=474.2s - ttt_chunk [1861/1893] bpb=1.120685 time=476.8s - ttt_chunk [1871/1893] bpb=1.120658 time=479.4s - ttt_chunk [1881/1893] bpb=1.120215 time=481.8s - ttt_chunk [1891/1893] bpb=1.119984 time=484.4s - ttt_chunk [1893/1893] bpb=1.120030 time=484.7s -ttt_sliding:done val_loss=1.887909 val_bpb=1.118129 elapsed=484.7s -legal_ttt val_loss:1.8879 val_bpb:1.1181 eval_time:485267ms -legal_ttt_exact val_loss:1.88790947 val_bpb:1.11812929