From e037ca92030c8c861eb4ec2aa2b8f230722edb73 Mon Sep 17 00:00:00 2001 From: Pavel Liashkov Date: Sat, 11 Apr 2026 19:48:01 +0700 Subject: [PATCH 1/2] =?UTF-8?q?Record:=20SP8192=20+=20Improved=20Parallel?= =?UTF-8?q?=20Residuals=20+=20Muon=200.97=20+=20LR=200.03=20+=20Legal=20TT?= =?UTF-8?q?T=20=E2=80=94=20val=5Fbpb=201.07785=20(3-seed=20mean)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 3-seed mean: 1.07785 (std 0.00047), seeds 42/314/999 All artifacts under 16MB, training under 600s, eval under 600s Improved parallel residuals (cross-lane routing), Muon 0.97, MATRIX_LR=0.03 Score-first TTT (SGD 3ep), no SLOT, no pre-quant TTT Co-Authored-By: Claude Opus 4.6 (1M context) --- .../README.md | 77 +++++++++ .../submission.json | 37 +++++ .../train_gpt.py | 2 + .../train_seed314.log | 149 ++++++++++++++++++ .../train_seed42.log | 149 ++++++++++++++++++ .../train_seed999.log | 149 ++++++++++++++++++ 6 files changed, 563 insertions(+) create mode 100644 records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/README.md create mode 100644 records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/submission.json create mode 100644 records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/train_gpt.py create mode 100644 records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/train_seed314.log create mode 100644 records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/train_seed42.log create mode 100644 records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/train_seed999.log diff --git a/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/README.md b/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/README.md new file mode 100644 index 0000000000..3243127bd6 --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/README.md @@ -0,0 +1,77 @@ +# Record: SP8192 + Improved Parallel Residuals + Muon 0.97 + LR 0.03 + Legal TTT + +**val_bpb = 1.07785** (3-seed mean, std 0.00047) | **~15.99 MB** | 8xH100 SXM + +## 3-Seed Results + +| Seed | Sliding BPP | **TTT BPP** | Artifact | +|------|-------------|-------------|----------| +| 42 | 1.07880 | **1.07718** | 15,990,780 | +| 314 | 1.07959 | **1.07810** | 15,987,449 | +| 999 | 1.07963 | **1.07826** | 15,987,550 | +| **Mean** | **1.07934** | **1.07785** | **15,988,593** | +| **Std** | **0.00039** | **0.00047** | | + +Merged SOTA (PR #1493, our previous): **1.0810 BPP**. Delta: **-0.0032 BPP**. + +## Key Techniques + +1. **Improved Parallel Residuals** (from PR #1529 @msisovic) -- cross-lane routing where attention and MLP outputs route to BOTH lanes via learned scalars. 66 new scalar params (`par_post[11,2,2]` + `par_resid[11,2]`). Final output = MLP lane (lane1). Starts at layer 7. + +2. **Muon Momentum 0.97** (from PR #1514 @dexhunter) -- reduced from 0.99. Shorter memory horizon (~33 steps) better tracks the rapidly changing loss surface during warmdown. + +3. **MATRIX_LR = 0.03** -- re-tuned for momentum 0.97 (higher LR pairs with lower momentum). Sweep: 0.022 → 1.0797, 0.03 → 1.0795, 0.04 → 1.0811. + +4. **3-Layer Depth Recurrence** (L3-5, activate at frac=0.35) -- 17 virtual layers from 11 physical. + +5. **QK-Gain 5.25** -- monotonic improvement from 4.0 to 5.25. + +6. **Legal Score-First TTT** -- SGD (lr=0.005, mom=0.9), 3 epochs per 32K-token chunk, cosine LR decay. + +7. **SP8192 + GPTQ SDClip** -- int6 matrices (k=12.85), int8 embeddings (k=20.0), Brotli-11 compression. + +8. **Tuned Hyperparameters** -- WD=0.095, EMA=0.9965, warmdown=0.72. + +## Architecture + +11L x 512d x 8H / 4KV, MLP 4x, LeakyReLU(0.5)^2, Partial RoPE (16/64 dims), layerwise LN scale, tied embeddings, logit softcap=30.0. Depth recurrence: encoder [0,1,2,3,4,5,3,4] decoder [5,3,4,5,6,7,8,9,10]. Improved parallel residuals from layer 7: attention reads from lane0, MLP reads from lane1, both outputs route to both lanes via learned `par_post` and `par_resid` scalars. Skip gates (sigmoid-gated U-Net connections). + +## Compliance (Track B) + +Per Issue #1017: +- **Condition 1 (Causality):** Sliding-window eval, prefix only +- **Condition 2 (Normalized):** Standard softmax, no n-gram/logit bias +- **Condition 3 (Score before update):** Each chunk scored under `torch.no_grad()` BEFORE SGD +- **Condition 4 (Single pass):** Each token scored once, no rescoring + +No SLOT, no pre-quant TTT, no ETLB, no n-gram cache. All artifacts < 16MB, train < 600s, eval < 600s. + +## Reproduction + +```bash +SEED=42 QK_GAIN_INIT=5.25 MUON_MOMENTUM=0.97 MATRIX_LR=0.03 \ + TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 \ + torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Credits + +- **@msisovic** -- Improved parallel residuals (PR #1529, #1204) +- **@clarkkev** -- SP8192 + GPTQ + SDClip + MuonEq-R (PR #1394) +- **@dexhunter** -- Muon 0.97 (PR #1514), depth recurrence (PR #1331, #1437), TTT on SP8192 (PR #1413) +- **@abaybektursun** -- Score-first TTT framework (PR #549) +- **@X-Abhishek-X** -- Hyperparameter tuning (PR #1445, #1471) +- **@Robby955** -- Parallel residuals on SP8192 (PR #1412) + +## Acknowledgements + +Thanks to OpenAI's Advanced Competitor grant ($500 compute credit via RunPod). + +## Included Files + +- `README.md` (this file) +- `submission.json` +- `train_gpt.py` +- `train_seed42.log` +- `train_seed314.log` +- `train_seed999.log` diff --git a/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/submission.json b/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/submission.json new file mode 100644 index 0000000000..455193afcf --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/submission.json @@ -0,0 +1,37 @@ +{ + "author": "bigbag", + "github_id": "bigbag", + "name": "SP8192 + Improved Parallel Residuals + Muon 0.97 + LR 0.03 + Legal Score-First TTT", + "date": "2026-04-11", + "track": "10min_16mb", + "val_bpb": 1.07785, + "val_bpb_std": 0.00047, + "seeds": [42, 314, 999], + "seed_results": { + "42": {"val_bpb": 1.07718, "artifact_bytes": 15990780}, + "314": {"val_bpb": 1.07810, "artifact_bytes": 15987449}, + "999": {"val_bpb": 1.07826, "artifact_bytes": 15987550} + }, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "technique_summary": "SP8192 + Improved Parallel Residuals (cross-lane routing L7+) + 3-Layer Depth Recurrence (L3-5) + Muon 0.97 + LR 0.03 + QK-Gain 5.25 + EMA 0.9965 + WD 0.095 + Score-First TTT (SGD 3ep) + GPTQ SDClip + Brotli", + "compliance": { + "train_under_600s": true, + "artifact_under_16mb": true, + "eval_under_600s": true, + "no_slot": true, + "no_pre_quant_ttt": true, + "no_etlb": true, + "no_ngram_cache": true, + "score_first_ttt": true, + "three_seeds": true + }, + "attribution": { + "sp8192_gptq_sdclip": "@clarkkev (PR #1394)", + "depth_recurrence": "@dexhunter (PR #1331, #1437)", + "improved_parallel_residuals": "@msisovic (PR #1529, #1204)", + "legal_ttt_framework": "@abaybektursun (PR #549), @dexhunter (PR #1413)", + "muon_097": "@dexhunter (PR #1514)", + "hyperparameter_tuning": "@X-Abhishek-X (PR #1445)" + } +} diff --git a/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/train_gpt.py b/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/train_gpt.py new file mode 100644 index 0000000000..64e8385dbb --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_SP8192_ImprovedParResid_Muon097_LR03_LegalTTT/train_gpt.py @@ -0,0 +1,2 @@ +import lzma as L,base64 as B +exec(L.decompress(B.b85decode(";KqP8Ook~)Km^%c^ys%R{D_%yAk9-_tV7^coUOo3$w>`(`ci)t`2F7>r>Ltx>>S2CRw|7ov>Wn1e~_!RLQ=%V9g?)G3yPsu%SBy!lj1PaC-x%dDmCDOZ^r^!)+WWz}ejKXTJ#^U6Ra!};QocHHXQC+4UM!QQ!-N5Xd|%~a(9)bTYIO+>B~8~@lqmri%^qEkQUy074Rh6w7V_#^s9J-3BNA`G;qyR$LYcI?e+loZVWi~B$n=TKFp{%SeHYp{oNWh;U@Ahk8M2$OU%K8B$lb*dRQXd-GR_@*KAZdRdwSd#v=LSq1v@Puul=a7WXDmh1^kBj}Y2XlER!D2E{&{%lV(hz$#n5%+%sk&Q}>{y0xpRgiQQBJeVV0hy8UD3ntyo@(Pv+K7^zVRDt4bah(r8kfsZThb+H1)~K-lIr4`|V#-2R>G7pP*N!fwWd&Dq8C)y=NrG_U_Oz6Q?+@ok1?(VJ5?ZT~&}C4Ks38WRB>3i=I!}H-8qq=&yKJ;tbpwwn~lAseD^q1C*u5T;lKQtF;?zv@u0f36%6SXU~txi3v5iSPK*`fNE9531KaQDL`zTPF$MX4U(-3sY-&?>QJe)giBQzpor7H)AZ#4=Hn#`AoAL7tT){&bw(fgz|eQRt`#6-<>;m*+&$!nf|od6&lVKYYHuOoNgZU_L>E@!O%__mlt=);Hwdc43+CM?sh5y+my3XSVYMO8F1pXuq$fvTU<$mpDjr>Lm){DeV)>4AKAhA?jxjH<-3yYQ#5qz+4c`Utifny+Ydmr4?c_z60#9@FU+U1&O$Lfg$WrX7gCj50O1t`1A`k04LVr;^*~{|@(TS5>#TAjL(B`umc8bVA$bS|F?^2A7E}z7IIgZlY(8Ex#K+nLh0vzlKK=74U!g+sX4T?e3_^_7XB1A(HB{pYd{vHYcak_P3DZ2LAB20wAP+C_9p7R|0}wA=p~JFi&xD8H}n(LxCc5rcmwF`!s(tSf_Fr`xA*>{``5_ONTCOibyqqb(@(_fio!M9VR)_%H?t4D^*@6P)Sk^W?gZLQ-%kL^)bh4p0}C)5<>6$_CEzK2?3S{Vq7wdD-aAHTuLsdDK-PSvVu8SoUPW&0w7c1H@K8yV0|N%%Fz83LT*P|DZM8Pk{Wuc2u{`JNA)Bz@$w9a-6Ii?B=hpdtRUo$AhZzDil)Fw)PPMB172a?5{YekB3=`R=+UGKm)e%$9VA34V1BE+XUPVOxan0xT|m)x;}CG8t1PY~CgH&P9`!IvP4#K!r7k%|Gj}<>VGZo4}`1M0!Hi@g7HptJ?R}qGJvl5g``IXSQber|>J>e#Y>!lfCg@M~~VFbh>q7K15z!7e&5Ql5h{lxvx+iO3KJSpbfYlun(~kWx(gE%%%uYa!)#Vi^x8=u!V?nawOmy&5}K{)0mMO2NG>Hk@u@WdeR3Gz!lIWYD2{Rrau)U2OBJBP`H9#Q#0g&syRS|>MN4dVfFN83CH4H4@knF$tacvtKT-btG5W1)eMwIGGCoXnY0*vSuDJo}e*Gse!>{1sH;SaxT-IV8D3pC_(bL8dUsVVQ?BuL@v;$7WkY;s9$^6bh@QfbNzhM_09#%6TZFz@a$%I~G?ih{4|y5!CcoiucYg8aGgTu1t(X@bdkGLIbkXy=)B1+hrl^XbQzxmQ}uhj5oHvu6907dw;lms&ZuXNeDhd_wW4WjutVd?HVVa4SYZkeQUR|ED9flkj4Fdh4Crl_MLwI@wEX{pihe^y8bi-u+ea^C$~8K@Uu!3=`~VD8TE|E+rk`TtL$?%=nTReOW>Whvuja0{##+=3N57|p0WqI`;kH)zY9z-KWbhulW|iCBdt28_7344AL)iXP_2Fi07gZvSDDxsnHSX1`WtvswAJ<)`l7&Y*Gmvh$ALANv{7iv@v6n|0fiRTifHr|IuQFCYXjb`H9#Ndd@9gE@K}oIP(HLmVM9gezXOTA$&zW8k?=E5P)0uhb{`mlnWH+I^wfzfCKBv8^4-(KpYtKP)Ew>)*aqyBdJJ%iHe5ACRI@e(ZMFI&4DMq11o=TcVej?V5X*x{cQKzt{(w+ouMH#B;~}!OxP3+;QR2I4W4s)YSaroB9p|xRZ+&M`4h*0OltPVBnG5M?It(NAZr7Z!~jY2Hm7#X~?@_a2!u+j80at6_Q(CPk0pj-uiV4nNcBQydzq}qYc;Hied|N?SOnLk_Kp~Y68YyFRNh|$C?I-O2Jq5d)Zlk04>OIIjezAswE7N#S$HFT&4ucxTaj$2wnGysPer;01Qz9urB66PHRjFR>~VrY|@2*V8f4MFBL(0Wd{-*ZClXffzd929%G%ULX|vzN|18a{Qmhru_A-XkjE|%HU~+*a$_i&q6-^<@6=BvJJrVx2v*{%dGTMJH%gEFXX(CH7M0dRIYvBbDH)?mpi0g;`)4$nW4!#BcU{Nf)@*wC&PPsS_o8a)TK|eJ)gtWr1H_O|n|p}U;nTS?;VgJ?A$&?)oK`o11#%e&^ci~x6|5pDq!sCx?;p-i(w7Zkt?QVp)cypVPnZv5F6f0||9txY%3TILMHCe#*Vm{m@nDjo$3JldqIkmhRnDt*EP!qfvl~r+oC^QWDCV<(*$DCggneARoBFD@IUwPr(KU?ZPMwo#Dr{huUJHPyV(etRG2|1`jkYj5K)q0oz@Z6EiW!c(Y#l44I(|9FS3~UjEei>piytPtFK(3>HR1DQN6;eGR$KFm>M#Ox?_QZ*fB%Q=~49sh#6Ft34eo*PBVwK5z6ou=uyrz<(*O{v6ZjS=moOu6l%-AsI`wkAiLC-y8Dz`GeL@3Yb~PnHc9Ukb^llWmD0N#lpA9v-X6)(mR{)q<^xY)B+9kczmd$2mwZlSpi~zH^)9lTsifb4t%Z^VNk|ExNhjHrJWEN+euZw(tx-DH!(JKqEShJbvsxlV^E~eKJ)O;efIUvs*2UG84H-kE(I*j8EAJ^%pOOFS_B1p6>I;O*AxFQT!Zd5FDW01)&osxt|3UYeDyvAEo1iQ@LcK?X&x=?+cy3pua_h)0MB#nJ#&VAz$}FgGld&6Dd_NPcmlXd!-)VB{p{F;m4h8JGy-C;-{H?fW~fcPN^E||2#-DhDE90TW2W4HKsDs+|Eh@ZB0p3_rp<5I@tJgUsy(!$OZljiL@zSy2Gq!UNgiof2lx8m<^+jWY9K2vmSao%XrICDe`u$2fh8pNeSR~;#*_LJEW%L{q_Wn<`Nn7C!ju_BluVzA4=i#Alocb7;VJb#Xyf4vVPB0LFLp$Hum%&nDP*AAOJWJwfMpOpO)z^Ec^OWj7i1{u@SB{jltcFk`gRvjP~QZDA{{RykvW1$Vy~aRLk>~LW2exy?@BBx%T^gBEwpvRQJ>pY$JMF5|L;MthvhSL^&tKvqVCS&b>o2h_PhT{i`sWf1P%$i0OeD_&>0wrjc`YXV%i)#_cd%LblW!u7vL~z8uW?}7+-%Z7AiNdJ%T`4!=SKi}bU00$wO#DL*1*Ry6x%4|$Kqd#_iJq48E;|sD=I;lmN+3M9E?ETRv#o+G)dcefdajB{?SSm_V7BM#n!63CH_CC__}w|L6d~HOZ^h5a`8-pPm!)0-0Uy?CYfz3(O0mvZdD+=#;xsqAy2p0o7(*LmmwzUm;U>%epB>o3llJ3F5pb+v_n%%xD@T`vzug&9G08)y78|FG=+f4E``bvOk*Sv?{;egOs5*to1(0wJD+E#z^}sB3IkIvrURxe`x_RDk`O`WgG1n0T#^1I3}_m7MB56Z6tu%BmWwg__rQXA0L`>``cjw#R{@R;;w0o-G`+EsFsSUoKjeAHGA4SDd(7Y6uJowxQZ!1yXQi;9#+j3G@cB>&7$+JlmRnIFoRxP>%_p{haKR|*LVdp+VPxua4x-7=KnNCEmxJ=nSOjCg(N$`BmRpzl&GLDSMR!W6Q9@Tk;jG|dY7!V6{KFaQ-;pu>EYBpvy>kc~P635{3U$=)+*c=#_t5C&@OA`2qZeBReffZBryZQxGult=RvjcPfB*r`a)s9?s%a#2~s+$<&!sf_}2pL7#|Bt>t&HFPwG0^_;bNf;}Qb$A#2^@w3ypFkTgTogls#6fFI!X2rIe-b0r0zjS`B=yBFy9epIi&hf6nJ%)KxXN_QxxCBx1}ZeWU3;r%Y^kx&+NXVQcu>p14r>hZUEx!Voe$v0dj5gp-zq-8yG~~@Oz`}0i?&86!^|g?TrBV&}fKW{t^JU^Loc*cCt;EC}gRc7k-eCV~5ZI}8j(DO9pl-_U%WhnT@o!H@$QH(0&$v~E}UEotq@{Q%OZ%$;%^P88RbVDomXp?X{_YFqe9lBg)H^d-HseH)q75qIW&qD>lW${;-VBB)(Yw0+u07@rh*c_Y0tZKw`*E2_!JDcLUDz@OtZO@kA%Jbh_hg=L{UO#`m74S1g!hG2@k;kHr+CB!L+nrhGED#c7(Xd69bMx26&+ua`4Y#b~33$>Rg0_+0iiA+y#fuTgl@i=~Eir}nYN-EG{j?FlMzF+N4Mq=`H=E{x{BY#(Sq7E~6>Xme#TXYPRLc+~vZYj=HvkgZ7=Nx~)~gfA==haeX*-RZv~^Up%Y{_QUNRZJ%vDu;p)@`(@0im$#6TrAE&lP#BEA!$TDHz-HDv4CpFnA%=4>Wy)nCUKaXXY__y&NK-R@0#@tklq?$mLS6~4hT9TI{3@FjmL)Xberb{#SKN#zA-NzqH^($enKb`Ii5rkw0SyKsTP!DNRl3J97jh=A2AFZVH)hj}*3=*U(BfG5cr7AtHxWn9VFH!ARn2VuvlGDz=Pr@aFrabA^9LqwjgY-^`RXyRtF*m`=-SvAo?7p($6hWl40xlrlbfdTz7_;^hUv&ivJ3R6+FXht-{pCNYS}EKe=@R_fR$#?2{532HEzTxL*s5C_rTSC`i-Yo}~okpzGBMSo}AlOr}J+QVQvQpUnH%+*k!VXb@t`62R8#wtLUb4muMCHHUxuTXBEXSQysedBqE1z0vH%A~4DGh-;E=!bb+dnoWswS3hyjyrK0r(=Q=K?aY4j#A!RssCPl0w4MS@IZ!4L7b_|7nTi+@ztZ*>c&wT%xJ02z{ESft7NQbOzwd>ZuU(lJo2coBfX?l(&QF%WK@KP4hkDXBKPh+##D6{;tXE%c$I~OxdgY)5;>HxXlKBDJK>E9&i<-^rhX$>?aa0;<}4nj$KCT1iZ*`lq5ZNPh)37fiG&IJcI_#y$9evdL&2R0jjD5{q6<*o8u%GbeEx6(*v>rg-@z_lfhM(!i!hU#iavvX#U($$;wYw``O3luE(nH0icO!d!sg78Ep%`hYr!Ec3h=xvwwkCLWtso|J@F61%SFk47`^059k?^jzf0K?`k6Ks$nL;aEHxs$6#&dkRpFR<8aJLlk5yg=PXK`=>n*%KmVCZk%8MC3uPm1+Rt=nAB=XrIg(&)po!>BeHAsUeZMBBaGg}bZCC#w{@#hBzLni{^xqP^1sYz2$|mBa4=B$gTE8xsPqOTb^&p}_RW9vj|zV~yI^oF&hxY6H+P-~#)8F^qV*M=GyxvISR~WoTTzb^qK}UPBw+FV9b%~)KwZcQVpal02%&cp@rpx-CH4`Pa#tITe1r=mg_23x-Ji9op6-hSH}P)^N<4I?P~n~mWHi?Pq1Lfk{7TGUfkh+2~T)j#3|{E&!9zQaSYiqR?)crCFx6L(i)WTR1!{Mm&>6mA&%eJ7+#5+X-A3ZA0FEleX<^2uSRf!=t-0T6O|wxPi%a9D%o#D^Ug91EyEqsl|w4k?GDH&AiP*Ge$d_2;6A!G$=-<_zQ^{vwHRai^3B93P#Z<8OE+=xF^B`}VLY$zDC8OoGJc0E8OJ|t^bi6jJ*7w4@o>%%(pj%j1FX%=+eck>{9nM;LDHPt){=)TenkAa_rJ;{6VLU%TdlJ=baD#mCpU28%XzL%t7NSYU9MC9RXWE-n}89g`YcP+xp7pv5H#`*{xcJss32SouK2aETz&rcDQwko!r$*DwEunB<+WPx_c+rzB&y9juamq5Ryhk$V}FPb>9SbIWP&Zu$udz)BvPpMova=s!pkAV%#p$<-t)dx^zug&cAJ=eTz9J1@YiY}w>oW(J0j;Zh4wF#rxvVu@r@j5Pd2vN*ybvO4&IheLi*%FY71Qf5}W)(vyTFYu~B4$gXUamVHkC%eBDXNjNd@b7oJypLA)*wc>CGYhm&M)y7d$8>=tHpTmbkZB&@(iX|zC$=DJ)AKnBJ7410obKeq8qBY5T(_?%y)#EIDKqJB3{fyD$fK^$v-NXVy;8a3~P{9yzA2YXK4MoQ{`Jq~#+oE9k)-17g1T5>0;yI`Hb~oks;FIn_x-?(bGNe!t=@qF{^X=LgsM!QS<8Dy76M^7Q2CoNsd3m@vuJ$dj2F+anQrp|s-tSFC2(^!H<^XIJAWn_m5#%qfRdcXR@bl`kRW`#kfo>bvy(q12N_3aF`vwpLzpEU^2`L%CA*@i7jX{AhpT}SBoui&MpaH@F`20;0m-?fL*0irm-f|UF#H+yAKUvK-mBG*u=rdEyGZ3aOx_puuy*EHfC2?W;ack3-bQS-nKz!}VjB-F2^DuMah8)AsHYObk-p^H7RbOGal|8h=IeaCV+!P;U~hP`KbaDAR;NLv_m%Zd_vCLhJlpRyx(atjYhH|5}MX*ee;D8B)m^*0*wOVeJ(93<4SX^?16k}9&FKr&9Xsh9g8wzkuxzR~`tu}=A*c#RCt)lS|@HhF@RWocUAJ!clmgJP#%^$P;0RxwrPb}@?`3uUN4dWCwU_GsCWP`ia*Lpb(B9&3w0%BREBI9^K>=(tKZPhVkPM{;nOQa;0N!+Z|MkqrDiN-$DPRlE<`zPkKmJMmDF?|#riP*kBduGv&e77=lySngzf{$^-fA(8Zz)y*u#tTC@wn#rz0D&(?DaX|-pF)tU~RCZ6&yV2yD2ED%~fIOSvqN|tFAZ#I>7;gD5pR#%PyO}$RLRz8h@p+O-xj%uFASSM44I5VnYb7OmeS>fK69?%QY`FguYuIlf0{BKl1^}Y;&xab!fVY@cxSSfkn}9Qh|HVmwgDjVSFaOLy1u!g~FMK**p}W@pMtLrYOx*GF@+SHJILF&~X`CCaflg`$mn>R?g)UF-Lf0fa+56|(DgxwLEO~}Md*=E4l#IHTMG8em;|O0vjunAGJPj~PKQ1cwm?a7!Y{Fl|d$EoDJk<(Gg78qW-zp8uOLmPSi_0~F=k6$O7t~(=9RP3;S*xhIB8&gnD-T>Bm4O1u2hiaAE{#5(Kbl{kbPR(Rp*=4?M>CADDFEB5n<(|t!R2-oW>YYjznMB|YS{ILQ-blviYG0@7K?K~FcXsS;bE3Yr`#I2B~NbyJ@54z`kLm$oqtG-yw>HH2KdN5=!Sdpz$rV9S|PGOMe4S|srwv4)PItIwU%0ET6rDY%_!L3z#SXM7(FB5&vbK`nbvfV>eK~r_Ef?fmxUQf@KKGcCwuU06F!QV56f3#3&kMUim>opJR4A+MxQV{n}Iklh67`kM(oC&M4QfN%&})6K}uqTf-R^i(r$u2&mOF)o=G*Uu`QGH(2W=5Vsh#uih8&50gl{56iMyR>Fcv~-xGumPaHh9P17t3MErKrj1x=X#LoN5s<6R<$nFiJfaXfxgroij0Eu6PP?kO$;ZtHczL-2t9laYS1Ns;Id9jn=+4^Cls2;|gk(8alaUxmJq*b@@V1R#Hh4+?f54*?~fbVE4qf!=pBAZh1P0HSfQ*MCBOU;?#Ttgj#npIQ{!6kIsp8Ad47EEa>etSFzIXz*Atw)qqu;0;8QBPd8d!MHc(&x3YMRija*K!hB(hn^mrL0zJP|a>2MJtoT$w}^>;QpDxJf9(9VbYfv?&Sho?M8118PjPd^dMCLiWn0?7r2*_gQ$-9R!PCdY*_)WCOyH;g{Hp=;+~DAaXQIMRovNS2$ucZvB>v`zVMiB5!#9Jv(Ppmte0|ap8WI+yenA;iFA9Fs#ua7I(mrYt{CwTZy;4lT&Go%*Vs+NgYG$8uvL-}iPV2`i}hQN4s4iFB@dG@-`DK~Ewd77)-CK>IvDoG$OH6QTw7wAQw;CK1n@2NA#BR#bygqP|vNvKt4i2U%NSq<`2|7<7U4|LFVss)Y>W_pPj0Fjmx^p)IEda9fz3EEdsDdq&S?>&6|E$KA*6@)QUJZ4-k8-ItDTeP^|3I{dhKWWXm0dSE=<*@1+cIX7oV0R<<75`op{>SaiC8lh(e6+KZQ6BOcn|i^RLnDTd8^}bHlWc`{Oqbnp9WbN>wC#1Lk~M#no98OtB~G%#StKU_*cHPjy|@&F@5@4L^!~e%c^l)HVo2aN#?R#{wwSWitk>Ep&oan&U7&^o6U%(sgIrMdEL}+0I;Y0C&Tn3`udRLxbewG5G|}P}a@C>}mWE*3W(R0@I7TBzt{Y~LL);*`cTwUhsRFp$hq_uvd9d6LOR1&tiaCY3E#g8v5~{d_baI%44Av17puVL#ym<2mU3o&(j)13Qp6>^AsR4{4&?RM;aLFPb39avfs_#sdxpu)~{rT3XbrH1)upfW^Rc_<5@bT_m&H;sLv)~7$r#R4K>S6#qKNOLsaS5tEq1NgQLJ-+*CWxh6Hw(7pN50BKC)xV@eZd*oV5i@o8zE7n_}Ah-E-{d)VhQqMC-ZqmEgS~GhO+bL?I?0xt-^*qgSG-wQk!cmM%Zc?r)*p)4h+G2Rwom1Wg!W)yq$@}c-XTrVVmE5V2MyjK0U5+)abe6%be~Ev*fVK-ElsG2s#q%mc;fXrl)~=_0{&z!w*l)UatEk&cmf7U;M6NNTtVy4(blUGu*uI*M;nb(@I?E2(z0&UFd{G2L*o`VkrDlHDJd}zGfgmUR@T^%oGIJlHJ00jMAU*R1FMmL+Ddhty)auBp|g976IpXi6<2Mj8!+k1i^?~*P>IYV9#u6{A&s&jB%(ORz%3LX`>UbPVUs9aZ5w0Th*H+QF)a`p9lQ$V9HA%wq}WyBouDL{%kTr)%5-R=X(B8BAEpnKT8uB!d|ROH9M;O(r=HI{|lkRz7bi{Nc-Kk51#TVcw=-Zr*3Z_Ek@f{JqHo{NIFdfyIkE~s%7v%J~~=WmwGFc>}IUDO#X}Y1xU2SpN_(wYjevq5Y@8%Rkk|XN7V4`b~Zv^V#k6jWuhsdK%p4*fmlFgsT3Zz9prIQ*v-#-Mv`9fxd2>x8ZWv_q|u3Fc;RSoIrmY+JopkWC`uS|Lw4Xco9*Npv73Bs_Z*D>30Mc7?SqFJZR||BKP|i_AwF5Jvx1uU{rt!1Dox>)|u@w)+u-%PR|nqY9QZ>91uY@CY>>mo}GmZlYTlv6(V-oH6+PUX2Ys%_TJD!QQyNAXHnD8Hj8=8IS7?TO}%OZPp~_>x9887VM#k4PU+l}TG34ylZ~bzGq702nHq=st3PPd2FMlhVI*e|)(gi!D%;u-qJ{B3mS@UUBXkRref%ohYwQ+)6YuSIj!Cc2%cjvc5n|c%Lxy{+RmP2`Wm*+Rjx}Qv%O#TH;$N$WYRp+TikOIt{m)X8!=IU5TQ;tqqI>M(&A?FI44Vx Date: Sun, 12 Apr 2026 03:42:32 +0000 Subject: [PATCH 2/2] =?UTF-8?q?feat(record):=20VarLen=20TTT=20+=20Warmdown?= =?UTF-8?q?=200.75=20+=20Chunk=2048=20=E2=80=94=20val=5Fbpb=201.07406=20(3?= =?UTF-8?q?-seed=20mean)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #1530 v2 base + warmdown_frac=0.75 + TTT_CHUNK_SIZE=48 + Muon 0.97. 3-seed mean: 1.07406 (std 0.00132), 2.77441 nats. Delta vs merged SOTA (#1493): -0.01491 nats (clears 0.005 bar by 3.0x). All artifacts < 16 MB, train < 600s, eval < 225s. --- .../README.md | 114 + .../submission.json | 39 + .../train_gpt.py | 2843 +++++++++++++++++ .../train_seed0.log | 279 ++ .../train_seed1337.log | 265 ++ .../train_seed42.log | 259 ++ 6 files changed, 3799 insertions(+) create mode 100644 records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/README.md create mode 100644 records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/submission.json create mode 100644 records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_gpt.py create mode 100644 records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed0.log create mode 100644 records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed42.log diff --git a/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/README.md b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/README.md new file mode 100644 index 0000000000..84fc32ff11 --- /dev/null +++ b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/README.md @@ -0,0 +1,114 @@ +# Record: VarLen Attention + Triton Fused MLP + Doc-TTT + Warmdown 0.75 + Chunk 48 + +**val_bpb = 1.07406** (3-seed mean, std 0.00132) | **2.77441 nats** | **~15.99 MB** | 8xH100 SXM, 600s + +## Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128) + +### Core Results + +| Seed | Steps | ms/step | Pre-TTT BPB | Post-TTT BPB | TTT Gain | TTT Time | Artifact | +|------|-------|---------|-------------|--------------|----------|----------|----------| +| 42 | 4918 | 119.4 | 1.08400 | **1.07352** | -0.01048 | 213s | 15,994,146 | +| 0 | 4900 | 119.8 | 1.08363 | **1.07310** | -0.01053 | 221s | 15,997,570 | +| 1337 | 4908 | 119.6 | 1.08619 | **1.07556** | -0.01063 | 219s | 15,988,610 | +| **Mean** | **4909** | **119.6** | **1.08461** | **1.07406** | **-0.01055** | **218s** | **15,993,442** | +| **Std** | | | | **0.00132** | | | | + +### Supplemental Diagnostics + +| Seed | Post-EMA BPB | Quantized BPB | Post-TTT BPB | val_loss (nats) | Code size | Total submission | Train time | Eval time | +|------|--------------|---------------|--------------|-----------------|-----------|------------------|------------|-----------| +| 42 | 1.07134 | 1.08400 | 1.07352 | 2.77301 | 2843 lines | 15,994,146 | 587.1s | 213s | +| 0 | 1.07160 | 1.08363 | 1.07310 | 2.77193 | 2843 lines | 15,997,570 | 587.1s | 221s | +| 1337 | 1.07339 | 1.08619 | 1.07556 | 2.77829 | 2843 lines | 15,988,610 | 587.1s | 219s | + +Merged SOTA (PR #1493 @bigbag): **1.0810 BPB** (2.78932 nats). Delta: **-0.01491 nats** (clears 0.005 bar by **3.0x**). + +## Key Innovation + +**Warmdown fraction and TTT chunk size tuning** on top of PR #1530's VarLen + Triton fused MLP + doc-TTT stack: + +- **warmdown_frac = 0.75** (up from 0.72 default) -- extends the cosine decay phase by 3%, allowing the model to settle into a slightly lower-loss basin before quantization. This alone gives ~0.001 BPB improvement. +- **TTT_CHUNK_SIZE = 48** (up from 32 default) -- larger document chunks provide more context per TTT gradient step, improving LoRA adaptation quality at a small compute cost. Combined with warmdown tuning, yields ~0.002 BPB total gain. +- **Muon momentum 0.97** -- shorter memory horizon (~33 effective steps) tracks the rapidly changing loss surface better during the extended warmdown phase. + +### Changes from PR #1530 v2 baseline + +| Parameter | PR #1530 v2 | This submission | +|-----------|-------------|-----------------| +| warmdown_frac | 0.72 | **0.75** | +| TTT_CHUNK_SIZE | 32 | **48** | +| MUON_MOMENTUM | 0.95 | **0.97** | + +## Architecture + +11L x 512d x 8H / 4KV, MLP 4x with Triton fused kernel (LeakyReLU(0.5)^2), Partial RoPE (16/64 dims), layerwise LN scale, tied embeddings, logit softcap=30.0. VarLen attention via Flash Attention 3 `flash_attn_varlen_func` for document-aware batching. Triple depth recurrence: layers 3-5 looped 3x (17 virtual layers from 11 physical, activates at frac=0.35). Parameter banking with batched Newton-Schulz orthogonalization. Parallel residuals from layer 8 with mean lane fusion. Skip gates (sigmoid-gated U-Net connections). + +**Optimizer**: Muon (momentum=0.97, 5-step Newton-Schulz, row-normalized) for matrix params + Adam (beta1=0.9, beta2=0.95) for scalars/embeddings. Split LR: matrix=0.022, embed=0.6, head=0.008, scalar=0.02. EMA decay=0.9965. Gradient clipping at 0.3. + +**Quantization**: Full Hessian GPTQ with int6 matrices (clip_sigmas=12.85), int8 embeddings (clip_sigmas=20.0), Brotli-11 compression. + +**TTT**: Doc-independent LoRA (rank=96) on K, MLP, and O projections. Adam optimizer (lr=0.0001, beta2=0.999), weight decay=0.5, chunk_size=48. Score-first: each chunk scored under `torch.no_grad()` before gradient update. + +## Rule Compliance + +Per Issue #1017: +- **Condition 1 (Causality):** VarLen attention with per-document `cu_seqlens` ensures strict causal masking within documents. No cross-document information leakage. +- **Condition 2 (Normalized):** Standard softmax over full vocabulary. No n-gram bias, no logit manipulation. +- **Condition 3 (Score before update):** Each TTT chunk scored under `torch.no_grad()` BEFORE any LoRA gradient update. Score-first ordering verified. +- **Condition 4 (Single pass):** Each token scored exactly once. No rescoring, no multi-pass evaluation. + +No SLOT, no pre-quant TTT on val data, no ETLB, no n-gram cache, no hashed n-gram. All artifacts < 16 MB, train < 600s, eval < 600s. Compile warmup uses random tokens (not val data). + +## Requirements + +- Python 3.10+ +- PyTorch 2.9.1+cu128 +- flash-attn-interface (Flash Attention 3) +- sentencepiece +- triton +- brotli +- numpy + +## Run Command + +```bash +# 3-seed verification loop (defaults baked into train_gpt.py) +for SEED in 42 0 1337; do + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py \ + 2>&1 | tee train_seed${SEED}.log +done +``` + +## Lineage + +PR #1530 v2 (@samacqua) -> warmdown/chunk/momentum tuning (this work) + +Built on: +- PR #1530 (@samacqua) -- VarLen attention, Triton fused MLP, doc-independent LoRA TTT, triple depth recurrence, parameter banking +- PR #1523 (@EthanYangTW) -- triple recurrence (NUM_LOOPS=2), parameter banking, fused MLP TMA +- PR #1514 (@dexhunter) -- Muon momentum 0.97 +- PR #1493 (@bigbag) -- merged SOTA baseline +- PR #1394 (@clarkkev) -- SP8192 + GPTQ + SDClip + MuonEq-R foundation + +## Credits + +- **@samacqua** -- VarLen attention, Triton fused MLP, doc-independent LoRA TTT, triple depth recurrence, parameter banking (PR #1530) +- **@EthanYangTW** -- Triple recurrence, parameter banking, fused MLP TMA (PR #1523) +- **@dexhunter** -- Muon momentum 0.97 (PR #1514), warmdown/chunk/momentum tuning (this work) +- **@bigbag** -- Merged SOTA baseline (PR #1493) +- **@clarkkev** -- SP8192 + GPTQ + SDClip + MuonEq-R (PR #1394) +- **@abaybektursun** -- Score-first TTT framework (PR #549) + +## Acknowledgements + +Thanks to OpenAI's Advanced Competitor grant ($500 compute credit via RunPod). + +## Included Files + +- `README.md` (this file) +- `submission.json` +- `train_gpt.py` +- `train_seed42.log` +- `train_seed0.log` +- `train_seed1337.log` diff --git a/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/submission.json b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/submission.json new file mode 100644 index 0000000000..908de30158 --- /dev/null +++ b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/submission.json @@ -0,0 +1,39 @@ +{ + "track": "track_10min_16mb", + "submission_date": "2026-04-12", + "val_bpb": 1.07406, + "val_bpb_std": 0.00132, + "val_loss": 2.77441, + "seeds": [42, 0, 1337], + "num_seeds": 3, + "hardware": "8xH100 SXM 80GB", + "train_time_seconds": 587, + "eval_time_seconds": 218, + "artifact_bytes_mean": 15993442, + "framework": "PyTorch 2.9.1+cu128", + "tokenizer": "sp8192", + "ttt_enabled": true, + "per_seed": { + "42": { + "val_bpb": 1.07352, + "val_loss": 2.77301, + "artifact_bytes": 15994146, + "steps": 4918, + "eval_time_seconds": 213 + }, + "0": { + "val_bpb": 1.07310, + "val_loss": 2.77193, + "artifact_bytes": 15997570, + "steps": 4900, + "eval_time_seconds": 221 + }, + "1337": { + "val_bpb": 1.07556, + "val_loss": 2.77829, + "artifact_bytes": 15988610, + "steps": 4908, + "eval_time_seconds": 219 + } + } +} diff --git a/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_gpt.py b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_gpt.py new file mode 100644 index 0000000000..60bb9f85cd --- /dev/null +++ b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_gpt.py @@ -0,0 +1,2843 @@ +import base64, collections, copy, fcntl, glob, io, json, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + sliding_window_enabled = bool(int(os.environ.get("SLIDING_WINDOW_ENABLED", "0"))) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + embedding_dim = int(os.environ.get("EMBEDDING_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.022)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 0.5)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.999)) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + ttt_output_dir = os.environ.get("TTT_OUTPUT_DIR", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "brotli") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 64)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 13.0)) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 8)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 2e1)) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + datasets_dir = os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}") + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + tokenizer_path = os.path.join( + data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model" + ) + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + eval_only_path = os.environ.get("EVAL_ONLY_PATH", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._next_batch = None + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + if self.bos_idx.size == 0: + self.bos_idx = np.array([0], dtype=np.int64) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = buf[:-1].to(dtype=torch.int64).pin_memory() + targets = buf[1:].to(dtype=torch.int64).pin_memory() + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + if self._next_batch is not None: + inputs, targets, cu_seqlens, max_seqlen = self._next_batch.result() + else: + inputs, targets, cu_seqlens, max_seqlen = self._prepare_batch( + num_tokens_local, self.max_seq_len + ) + self._next_batch = self._batch_pool.submit( + self._prepare_batch, num_tokens_local, self.max_seq_len + ) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.5 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.5 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.5 * c0) + aux1 = tl.where(c1 > 0, c1, 0.5 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 256, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + logits = self.forward_logits( + input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), + reduction="mean", + ) + + def forward_ttt(self, input_ids, target_ids, lora): + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q = (F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n)).reshape( + bsz, seqlen, attn.num_heads, attn.head_dim + ) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q = (F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n)).reshape( + bsz, seqlen, attn.num_heads, attn.head_dim + ) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) + + +class BatchedTTTLoRA(nn.Module): + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + self.v_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +def classify_param(name): + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or ".proj." in name and ".mlp." not in name: + return "attn" + return "other" + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + a, b, c = 3.4445, -4.775, 2.0315 + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad.bfloat16()) + if pg.shape[0] > m["B"]: + pg[m["B"] :].zero_() + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [ + { + "params": [base_model.lm_head.weight], + "lr": h.head_lr, + "base_lr": h.head_lr, + } + ], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + if base_model.lm_head is not None: + self.replicated_params.append(base_model.lm_head.weight) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self.optimizer_tok.step() + self.optimizer_scalar.step() + if self.optimizer_head is not None: + self.optimizer_head.step() + self.optimizer_muon.step() + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank"): + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + if model.tie_embeddings: + hook_module = ( + model.head_proj if model.head_proj is not None else model.final_norm + ) + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s + + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + cs = h.embed_clip_sigmas if "tok_emb" in name else h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + q, s = gptq_quantize_weight( + t, hessians[name], clip_sigmas=cs, clip_range=2 ** (bits - 1) - 1 + ) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = ( + q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + ).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off : dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off : src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data, compressor): + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + raw = _byte_unshuffle(raw) + return raw + + +def _unbank_state_dict(state_dict, num_layers): + sd = {} + n = num_layers + for k, v in state_dict.items(): + t = v.detach().cpu() + if k == "qo_bank": + for i in range(n): + sd[f"blocks.{i}.attn.c_q.weight"] = t[i] + sd[f"blocks.{i}.attn.proj.weight"] = t[n + i] + elif k == "kv_bank": + for i in range(n): + sd[f"blocks.{i}.attn.c_k.weight"] = t[i] + sd[f"blocks.{i}.attn.c_v.weight"] = t[n + i] + elif k == "mlp_up_bank": + for i in range(n): + sd[f"blocks.{i}.mlp.fc.weight"] = t[i] + elif k == "mlp_down_bank": + for i in range(n): + sd[f"blocks.{i}.mlp.proj.weight"] = t[i] + else: + sd[k] = t + return sd + + +def _rebank_state_dict(flat_sd, num_layers, model_dim, kv_dim, hidden_dim): + sd = {} + n = num_layers + sd["qo_bank"] = torch.zeros(2 * n, model_dim, model_dim) + sd["kv_bank"] = torch.zeros(2 * n, kv_dim, model_dim) + sd["mlp_up_bank"] = torch.zeros(n, hidden_dim, model_dim) + sd["mlp_down_bank"] = torch.zeros(n, model_dim, hidden_dim) + for i in range(n): + sd["qo_bank"][i] = flat_sd[f"blocks.{i}.attn.c_q.weight"] + sd["qo_bank"][n + i] = flat_sd[f"blocks.{i}.attn.proj.weight"] + sd["kv_bank"][i] = flat_sd[f"blocks.{i}.attn.c_k.weight"] + sd["kv_bank"][n + i] = flat_sd[f"blocks.{i}.attn.c_v.weight"] + sd["mlp_up_bank"][i] = flat_sd[f"blocks.{i}.mlp.fc.weight"] + sd["mlp_down_bank"][i] = flat_sd[f"blocks.{i}.mlp.proj.weight"] + for k, v in flat_sd.items(): + if not ( + k.startswith("blocks.") + and any( + p in k + for p in [ + ".attn.c_q.", ".attn.c_k.", ".attn.c_v.", + ".attn.proj.", ".mlp.fc.", ".mlp.proj.", + ] + ) + ): + sd[k] = v + return sd + + +def _compressed_code_size(code): + code_raw = code.encode("utf-8") + minified = subprocess.run( + ["pyminify", "--no-rename-locals", "--no-hoist-literals", "--remove-literal-statements", "-"], + input=code_raw, capture_output=True, check=True, + ).stdout + compressed = lzma.compress(minified) + encoded = base64.b85encode(compressed) + wrapper = b'import lzma as L,base64 as B\nexec(L.decompress(B.b85decode("' + encoded + b'")))\n' + return len(code_raw), len(wrapper) + + +def serialize(h, base_model, code): + code_bytes_uncompressed, code_bytes = _compressed_code_size(code) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size (uncompressed): {code_bytes_uncompressed} bytes") + log(f"Code size (compressed): {code_bytes} bytes") + sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers) + device = torch.device("cuda", h.local_rank) + log("GPTQ:collecting Hessians from calibration data...") + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + hessians = collect_hessians( + base_model, + calib_loader, + h, + device, + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter()-t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + flat_template = _unbank_state_dict(eval_model.state_dict(), h.num_layers) + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location="cpu" + ) + deq_flat = dequantize_mixed(quant_state["w"], quant_state["m"], flat_template) + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + deq_state = _rebank_state_dict(deq_flat, h.num_layers, h.model_dim, kv_dim, hidden_dim) + eval_model.load_state_dict(deq_state, strict=True) + return eval_model + + +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += ( + val_data.has_leading_space_lut[tgt_ids] + & ~val_data.is_boundary_token_lut[prev_ids] + ).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding(h, device, val_data, base_model, forward_logits_fn=None, batch_seqs=32): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + run_forward_logits = base_model.forward_logits if forward_logits_fn is None else forward_logits_fn + seq_len = h.eval_seq_len + stride = h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + context_size = seq_len - stride + window_starts = [ws for ws in range(0, total_tokens, stride) + if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * (h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + total_batches = (len(my_windows) + batch_seqs - 1) // batch_seqs + is_master = h.rank == 0 + cu_bucket = 64 + t_sw_start = time.perf_counter() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_idx = bi // batch_seqs + if is_master and (batch_idx % 50 == 0 or batch_idx == total_batches - 1): + elapsed = time.perf_counter() - t_sw_start + rl = float(loss_sum.item() / token_count.item()) if token_count.item() > 0 else 0.0 + rb = float((rl / math.log(2.0)) * token_count.item() / byte_count.item()) if byte_count.item() > 0 else 0.0 + log(f"sliding_progress: batch {batch_idx+1}/{total_batches} " + f"tokens:{int(token_count.item())} running_loss:{rl:.4f} running_bpb:{rb:.4f} " + f"elapsed:{elapsed:.1f}s") + batch_ws = my_windows[bi:bi + batch_seqs] + x_parts = [] + y_parts = [] + cu_starts = [] + score_ranges = [] + offset = 0 + for ws in batch_ws: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + chunk_cpu = val_data.val_tokens[ws:end + 1] + bos_pos = (chunk_cpu[:-1] == BOS_ID).nonzero(as_tuple=True)[0].tolist() + if not bos_pos or bos_pos[0] != 0: + bos_pos = [0] + bos_pos + cu_starts.extend(offset + pos for pos in bos_pos) + chunk = chunk_cpu.to(dtype=torch.int64, device=device) + x_parts.append(chunk[:-1]) + y_parts.append(chunk[1:]) + score_ranges.append((offset, wlen, ws)) + offset += wlen + x_cat = torch.cat(x_parts, dim=0)[None] + y_cat = torch.cat(y_parts, dim=0) + boundaries = cu_starts + [offset] + padded_len = get_next_multiple_of_n(len(boundaries), cu_bucket) + cu_seqlens = torch.full((padded_len,), offset, dtype=torch.int32, device=device) + cu_seqlens[:len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = run_forward_logits(x_cat, cu_seqlens=cu_seqlens, max_seqlen=seq_len) + flat_nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_cat, + reduction="none", + ) + flat_x = x_cat.reshape(-1) + for off, wlen, ws in score_ranges: + s = 0 if ws == 0 else context_size + lo = off + s + hi = off + wlen + scored_nll = flat_nll[lo:hi].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(hi - lo) + tgt = y_cat[lo:hi] + prev = flat_x[lo:hi] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + +def eval_val_ttt_lora(h, base_model, device, val_data, forward_ttt_train): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + docs = _find_docs(all_tokens) + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + doc_entries = [(i, docs[i]) for i in sampled_indices] + log( + f"ttt_lora:docs:{len(doc_entries)} rank:{h.ttt_lora_rank} lr:{h.ttt_lora_lr} chunk:{h.ttt_chunk_size}" + ) + if os.environ.get("TTT_DEBUG_BYPASS") and h.rank == 0: + test_doc = doc_entries[0][1] + ds, dl = test_doc + log(f"DEBUG: test doc start={ds} len={dl}") + toks = all_tokens_idx[ds : ds + dl].to(device=device, dtype=torch.int64) + x_d = toks[:-1].unsqueeze(0) + y_d = toks[1:].unsqueeze(0) + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits_d = base_model.forward_logits(x_d) + ptl_d = F.cross_entropy( + logits_d.float().reshape(-1, logits_d.size(-1)), + y_d.reshape(-1), reduction="none", + ) + direct_loss = ptl_d.mean().item() + direct_bpb = direct_loss / math.log(2.0) + log(f"DEBUG: direct forward_logits loss={direct_loss:.6f} bpb={direct_bpb:.6f} ntokens={y_d.numel()}") + toks_first5 = toks[:5].tolist() + ptl_first5 = ptl_d[:5].tolist() + log(f"DEBUG: first 5 tokens={toks_first5} ptl={[f'{v:.4f}' for v in ptl_first5]}") + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches(doc_entries, h, ascending=use_ascending) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path] + dist.broadcast_object_list(path_list, src=0) + counter_path = path_list[0] + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + progress_f = None + if h.ttt_output_dir and h.rank == 0: + os.makedirs(h.ttt_output_dir, exist_ok=True) + progress_f = open(os.path.join(h.ttt_output_dir, "progress.jsonl"), "w") + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + with torch.no_grad(): + _accumulate_bpb( + per_tok_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + ) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + for gi in range(h.ttt_grad_steps): + if gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[ + :, chunk_offset : chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = False + if eval_batch_set is not None: + should_report = batch_num in eval_batch_set + else: + # should_report = local_batch_count % 10 == 0 + should_report = True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + if dt > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / (cur_bytes_val - prev_bytes)) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttt_progress: batch {batch_num}/{queue_len} batch_loss:{b_loss:.4f} " + f"batch_bpb:{b_bpb:.4f} running_loss:{r_loss:.4f} running_bpb:{r_bpb:.4f} " + f"doc_len:{min(doc_lens)}-{max(doc_lens)}" + ) + if progress_f is not None: + progress_f.write( + json.dumps({ + "batch": batch_num, "total_batches": queue_len, + "batch_loss": round(b_loss, 8), "batch_bpb": round(b_bpb, 8), + "running_loss": round(r_loss, 8), "running_bpb": round(r_bpb, 8), + "doc_len_min": min(doc_lens), "doc_len_max": max(doc_lens), + "chunk_size": chunk_size, + "elapsed_s": round(elapsed, 3), + "batch_t_s": round(elapsed, 3), + }) + "\n" + ) + progress_f.flush() + del cur_lora, cur_opt + finally: + if progress_f is not None: + progress_f.close() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + frac = ( + min(step / h.muon_momentum_warmup_steps, 1.0) + if h.muon_momentum_warmup_steps > 0 + else 1.0 + ) + muon_momentum = ( + 1 - frac + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + ema_state = { + name: t.detach().float().clone() + for (name, t) in base_model.state_dict().items() + } + ema_decay = h.ema_decay + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale) + with torch.no_grad(): + for (name, t) in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_( + t.detach().float(), alpha=1.0 - ema_decay + ) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + if h.eval_only_path: + log(f"eval_only:loading checkpoint from {h.eval_only_path}") + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + base_model.load_state_dict(torch.load(h.eval_only_path, map_location=device)) + if h.num_loops > 0: + base_model.looping_active = True + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + else: + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + _skip_training = bool(h.eval_only_path) + torch._dynamo.reset() + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if not _skip_training: + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + else: + log("eval_only: skipping serialize (already have quantized model)") + if not os.path.exists(h.quantized_model_path): + log("eval_only: no quantized model found, running serialize anyway") + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if h.sliding_window_enabled: + timed_eval( + "diagnostic quantized_sliding_window", + eval_val_sliding, + h, + device, + val_data, + eval_model, + forward_logits_fn=compiled_forward_logits, + ) + if h.ttt_enabled: + del eval_model, compiled_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + _fwd_ttt_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora): + nonlocal _fwd_ttt_compiled_inner + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + + _ttt_debug_bypass = bool(os.environ.get("TTT_DEBUG_BYPASS")) + if _ttt_debug_bypass: + def _fwd_ttt_bypass(input_ids, target_ids, lora): + logits = ttt_model.forward_logits(input_ids) + dummy = lora.q_loras[0].B.sum() * 0 + logits = logits + dummy + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + fwd_ttt_compiled = _fwd_ttt_bypass + log("ttt_lora:DEBUG BYPASS active - using forward_logits directly (no compile warmup)") + else: + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile (random tokens, no val data)") + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + f"quantized_ttt_lora val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + del ttt_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 16 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log( + subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ).stdout, + console=False, + ) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed0.log b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed0.log new file mode 100644 index 0000000000..5ec99c30d5 --- /dev/null +++ b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed0.log @@ -0,0 +1,279 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /home/dex/parameter-golf-with-cc/data + datasets_dir: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_only_path: + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/PR1530_v2_wd75_c48_s0.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: PR1530_v2_wd75_c48_s0 + scalar_lr: 0.02 + seed: 0 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /home/dex/parameter-golf-with-cc/data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_output_dir: + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +gptq:reserving 13s, effective=587000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0086 val_bpb: 3.4874 +1/20000 train_loss: 9.0089 train_time: 0.0m tok/s: 16620838 +2/20000 train_loss: 12.2848 train_time: 0.0m tok/s: 12889214 +3/20000 train_loss: 11.2645 train_time: 0.0m tok/s: 11084615 +4/20000 train_loss: 9.6515 train_time: 0.0m tok/s: 10317575 +5/20000 train_loss: 8.2550 train_time: 0.0m tok/s: 9911593 +500/20000 train_loss: 3.2500 train_time: 0.8m tok/s: 8302001 +1000/20000 train_loss: 3.0092 train_time: 1.6m tok/s: 8285532 +1500/20000 train_loss: 3.0180 train_time: 2.4m tok/s: 8270932 +2000/20000 train_loss: 2.9694 train_time: 3.2m tok/s: 8272258 +layer_loop:enabled step:2160 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0582 train_time: 4.2m tok/s: 7777629 +3000/20000 train_loss: 2.8957 train_time: 5.4m tok/s: 7314251 +3500/20000 train_loss: 2.9659 train_time: 6.5m tok/s: 7019563 +4000/20000 train_loss: 2.8973 train_time: 7.7m tok/s: 6811688 +4000/20000 val_loss: 2.8751 val_bpb: 1.1130 +4500/20000 train_loss: 2.8569 train_time: 8.9m tok/s: 6660038 +4900/20000 val_loss: 2.7695 val_bpb: 1.0721 +stopping_early: wallclock_cap train_time: 587096ms step: 4900/20000 +peak memory allocated: 40019 MiB reserved: 44090 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.76815256 val_bpb:1.07160358 eval_time:6036ms +Serialized model: 135409136 bytes +Code size (uncompressed): 116419 bytes +Code size (compressed): 26790 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.6s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15970780 bytes +Total submission size quantized+brotli: 15997570 bytes +diagnostic quantized val_loss:2.79920961 val_bpb:1.08362635 eval_time:9294ms +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (90.2s) + +beginning TTT eval timer +ttt_lora:docs:50000 rank:96 lr:0.0001 chunk:48 +ttt_progress: batch 776/782 batch_loss:2.7183 batch_bpb:1.0876 running_loss:2.7183 running_bpb:1.0876 doc_len:6364-7180 +ttt_progress: batch 773/782 batch_loss:2.6535 batch_bpb:1.0767 running_loss:2.6897 running_bpb:1.0828 doc_len:5203-5550 +ttt_progress: batch 768/782 batch_loss:2.7042 batch_bpb:1.0852 running_loss:2.6934 running_bpb:1.0834 doc_len:4128-4306 +ttt_progress: batch 762/782 batch_loss:2.8257 batch_bpb:1.0755 running_loss:2.7167 running_bpb:1.0820 doc_len:3431-3533 +ttt_progress: batch 756/782 batch_loss:2.7821 batch_bpb:1.0783 running_loss:2.7254 running_bpb:1.0815 doc_len:2973-3032 +ttt_progress: batch 750/782 batch_loss:2.8270 batch_bpb:1.0665 running_loss:2.7360 running_bpb:1.0799 doc_len:2638-2688 +ttt_progress: batch 745/782 batch_loss:2.7869 batch_bpb:1.0894 running_loss:2.7405 running_bpb:1.0807 doc_len:2421-2458 +ttt_progress: batch 738/782 batch_loss:2.7465 batch_bpb:1.0537 running_loss:2.7409 running_bpb:1.0787 doc_len:2194-2227 +ttt_progress: batch 731/782 batch_loss:2.7737 batch_bpb:1.0586 running_loss:2.7430 running_bpb:1.0774 doc_len:2017-2041 +ttt_progress: batch 723/782 batch_loss:2.7810 batch_bpb:1.0610 running_loss:2.7451 running_bpb:1.0764 doc_len:1861-1885 +ttt_progress: batch 716/782 batch_loss:2.8115 batch_bpb:1.0375 running_loss:2.7483 running_bpb:1.0744 doc_len:1739-1754 +ttt_progress: batch 709/782 batch_loss:2.7843 batch_bpb:1.0579 running_loss:2.7499 running_bpb:1.0737 doc_len:1649-1661 +ttt_progress: batch 705/782 batch_loss:2.7839 batch_bpb:1.0723 running_loss:2.7513 running_bpb:1.0736 doc_len:1606-1617 +ttt_progress: batch 699/782 batch_loss:2.8181 batch_bpb:1.0429 running_loss:2.7539 running_bpb:1.0724 doc_len:1543-1552 +ttt_progress: batch 689/782 batch_loss:2.7831 batch_bpb:1.0648 running_loss:2.7549 running_bpb:1.0721 doc_len:1450-1458 +ttt_progress: batch 684/782 batch_loss:2.7940 batch_bpb:1.0741 running_loss:2.7562 running_bpb:1.0722 doc_len:1407-1414 +ttt_progress: batch 675/782 batch_loss:2.8366 batch_bpb:1.0648 running_loss:2.7586 running_bpb:1.0719 doc_len:1341-1347 +ttt_progress: batch 667/782 batch_loss:2.8174 batch_bpb:1.1037 running_loss:2.7602 running_bpb:1.0728 doc_len:1288-1295 +ttt_progress: batch 666/782 batch_loss:2.8213 batch_bpb:1.0603 running_loss:2.7619 running_bpb:1.0725 doc_len:1282-1288 +ttt_progress: batch 659/782 batch_loss:2.7132 batch_bpb:1.0217 running_loss:2.7606 running_bpb:1.0711 doc_len:1239-1245 +ttt_progress: batch 651/782 batch_loss:2.7208 batch_bpb:1.0450 running_loss:2.7597 running_bpb:1.0705 doc_len:1193-1198 +ttt_progress: batch 645/782 batch_loss:2.7953 batch_bpb:1.0937 running_loss:2.7605 running_bpb:1.0710 doc_len:1160-1166 +ttt_progress: batch 635/782 batch_loss:2.7404 batch_bpb:1.0608 running_loss:2.7601 running_bpb:1.0708 doc_len:1111-1116 +ttt_progress: batch 626/782 batch_loss:2.8106 batch_bpb:1.0443 running_loss:2.7611 running_bpb:1.0703 doc_len:1068-1073 +ttt_progress: batch 618/782 batch_loss:2.7370 batch_bpb:1.0493 running_loss:2.7606 running_bpb:1.0698 doc_len:1031-1037 +ttt_progress: batch 612/782 batch_loss:2.8251 batch_bpb:1.0435 running_loss:2.7618 running_bpb:1.0693 doc_len:1007-1012 +ttt_progress: batch 605/782 batch_loss:2.7450 batch_bpb:1.0588 running_loss:2.7615 running_bpb:1.0692 doc_len:978-982 +ttt_progress: batch 598/782 batch_loss:2.7992 batch_bpb:1.0662 running_loss:2.7621 running_bpb:1.0691 doc_len:950-954 +ttt_progress: batch 595/782 batch_loss:2.7277 batch_bpb:1.0546 running_loss:2.7616 running_bpb:1.0689 doc_len:940-943 +ttt_progress: batch 588/782 batch_loss:2.7343 batch_bpb:1.0432 running_loss:2.7612 running_bpb:1.0685 doc_len:917-921 +ttt_progress: batch 579/782 batch_loss:2.6416 batch_bpb:1.0068 running_loss:2.7594 running_bpb:1.0675 doc_len:888-891 +ttt_progress: batch 572/782 batch_loss:2.9391 batch_bpb:1.1185 running_loss:2.7619 running_bpb:1.0683 doc_len:865-868 +ttt_progress: batch 563/782 batch_loss:2.8033 batch_bpb:1.0634 running_loss:2.7625 running_bpb:1.0682 doc_len:837-840 +ttt_progress: batch 556/782 batch_loss:2.8257 batch_bpb:1.0803 running_loss:2.7633 running_bpb:1.0684 doc_len:815-818 +ttt_progress: batch 549/782 batch_loss:2.7619 batch_bpb:1.0626 running_loss:2.7633 running_bpb:1.0683 doc_len:795-798 +ttt_progress: batch 542/782 batch_loss:2.8301 batch_bpb:1.0721 running_loss:2.7641 running_bpb:1.0683 doc_len:777-779 +ttt_progress: batch 535/782 batch_loss:2.7831 batch_bpb:1.0552 running_loss:2.7644 running_bpb:1.0682 doc_len:759-762 +ttt_progress: batch 528/782 batch_loss:2.7513 batch_bpb:1.0307 running_loss:2.7642 running_bpb:1.0677 doc_len:742-745 +ttt_progress: batch 521/782 batch_loss:2.7661 batch_bpb:1.0496 running_loss:2.7642 running_bpb:1.0675 doc_len:725-727 +ttt_progress: batch 514/782 batch_loss:2.9054 batch_bpb:1.0959 running_loss:2.7657 running_bpb:1.0679 doc_len:707-710 +ttt_progress: batch 507/782 batch_loss:2.7585 batch_bpb:1.0414 running_loss:2.7656 running_bpb:1.0676 doc_len:690-693 +ttt_progress: batch 500/782 batch_loss:2.8358 batch_bpb:1.0832 running_loss:2.7663 running_bpb:1.0677 doc_len:675-677 +ttt_progress: batch 493/782 batch_loss:2.8411 batch_bpb:1.1141 running_loss:2.7671 running_bpb:1.0682 doc_len:659-661 +ttt_progress: batch 486/782 batch_loss:2.7977 batch_bpb:1.0619 running_loss:2.7673 running_bpb:1.0681 doc_len:645-646 +ttt_progress: batch 479/782 batch_loss:2.7103 batch_bpb:1.0344 running_loss:2.7668 running_bpb:1.0678 doc_len:630-632 +ttt_progress: batch 472/782 batch_loss:2.8018 batch_bpb:1.0710 running_loss:2.7671 running_bpb:1.0678 doc_len:616-618 +ttt_progress: batch 466/782 batch_loss:2.8020 batch_bpb:1.0653 running_loss:2.7674 running_bpb:1.0678 doc_len:604-606 +ttt_progress: batch 459/782 batch_loss:2.7402 batch_bpb:1.0398 running_loss:2.7672 running_bpb:1.0676 doc_len:591-593 +ttt_progress: batch 451/782 batch_loss:2.7744 batch_bpb:1.0628 running_loss:2.7673 running_bpb:1.0675 doc_len:576-579 +ttt_progress: batch 444/782 batch_loss:2.6744 batch_bpb:1.0133 running_loss:2.7665 running_bpb:1.0671 doc_len:564-566 +ttt_progress: batch 437/782 batch_loss:2.8731 batch_bpb:1.0602 running_loss:2.7673 running_bpb:1.0671 doc_len:551-553 +ttt_progress: batch 430/782 batch_loss:2.7585 batch_bpb:1.0472 running_loss:2.7673 running_bpb:1.0669 doc_len:539-540 +ttt_progress: batch 423/782 batch_loss:2.7382 batch_bpb:1.0285 running_loss:2.7671 running_bpb:1.0666 doc_len:526-528 +ttt_progress: batch 419/782 batch_loss:2.7936 batch_bpb:1.0382 running_loss:2.7673 running_bpb:1.0664 doc_len:519-521 +ttt_progress: batch 412/782 batch_loss:2.7113 batch_bpb:1.0530 running_loss:2.7669 running_bpb:1.0663 doc_len:508-510 +ttt_progress: batch 405/782 batch_loss:2.8215 batch_bpb:1.0665 running_loss:2.7672 running_bpb:1.0663 doc_len:497-498 +ttt_progress: batch 398/782 batch_loss:2.8788 batch_bpb:1.0934 running_loss:2.7679 running_bpb:1.0665 doc_len:486-487 +ttt_progress: batch 391/782 batch_loss:2.8180 batch_bpb:1.0975 running_loss:2.7683 running_bpb:1.0667 doc_len:475-476 +ttt_progress: batch 384/782 batch_loss:2.8534 batch_bpb:1.0947 running_loss:2.7688 running_bpb:1.0669 doc_len:464-466 +ttt_progress: batch 377/782 batch_loss:2.8012 batch_bpb:1.0861 running_loss:2.7690 running_bpb:1.0670 doc_len:454-455 +ttt_progress: batch 370/782 batch_loss:2.6797 batch_bpb:1.0426 running_loss:2.7684 running_bpb:1.0668 doc_len:444-446 +ttt_progress: batch 363/782 batch_loss:2.7416 batch_bpb:1.0933 running_loss:2.7683 running_bpb:1.0670 doc_len:434-436 +ttt_progress: batch 354/782 batch_loss:2.7964 batch_bpb:1.0849 running_loss:2.7684 running_bpb:1.0671 doc_len:422-423 +ttt_progress: batch 347/782 batch_loss:2.8564 batch_bpb:1.0888 running_loss:2.7689 running_bpb:1.0672 doc_len:413-414 +ttt_progress: batch 340/782 batch_loss:2.8209 batch_bpb:1.0912 running_loss:2.7692 running_bpb:1.0673 doc_len:403-404 +ttt_progress: batch 333/782 batch_loss:2.9049 batch_bpb:1.1313 running_loss:2.7698 running_bpb:1.0676 doc_len:394-395 +ttt_progress: batch 326/782 batch_loss:2.8608 batch_bpb:1.1307 running_loss:2.7703 running_bpb:1.0679 doc_len:385-387 +ttt_progress: batch 318/782 batch_loss:2.8156 batch_bpb:1.0680 running_loss:2.7705 running_bpb:1.0679 doc_len:374-376 +ttt_progress: batch 311/782 batch_loss:2.8528 batch_bpb:1.0929 running_loss:2.7708 running_bpb:1.0680 doc_len:365-367 +ttt_progress: batch 304/782 batch_loss:2.9095 batch_bpb:1.1331 running_loss:2.7714 running_bpb:1.0683 doc_len:357-358 +ttt_progress: batch 297/782 batch_loss:2.7950 batch_bpb:1.0589 running_loss:2.7715 running_bpb:1.0683 doc_len:348-349 +ttt_progress: batch 290/782 batch_loss:2.8584 batch_bpb:1.0829 running_loss:2.7719 running_bpb:1.0683 doc_len:340-341 +ttt_progress: batch 283/782 batch_loss:2.7984 batch_bpb:1.0735 running_loss:2.7720 running_bpb:1.0683 doc_len:332-333 +ttt_progress: batch 277/782 batch_loss:2.8098 batch_bpb:1.1066 running_loss:2.7721 running_bpb:1.0685 doc_len:325-326 +ttt_progress: batch 270/782 batch_loss:2.7779 batch_bpb:1.0902 running_loss:2.7722 running_bpb:1.0686 doc_len:318-319 +ttt_progress: batch 263/782 batch_loss:2.8343 batch_bpb:1.1039 running_loss:2.7724 running_bpb:1.0687 doc_len:310-311 +ttt_progress: batch 256/782 batch_loss:2.8866 batch_bpb:1.1316 running_loss:2.7728 running_bpb:1.0689 doc_len:301-302 +ttt_progress: batch 249/782 batch_loss:2.9033 batch_bpb:1.1564 running_loss:2.7733 running_bpb:1.0692 doc_len:294-295 +ttt_progress: batch 242/782 batch_loss:2.8987 batch_bpb:1.1083 running_loss:2.7737 running_bpb:1.0694 doc_len:287-288 +ttt_progress: batch 235/782 batch_loss:2.9307 batch_bpb:1.1140 running_loss:2.7742 running_bpb:1.0695 doc_len:280-281 +ttt_progress: batch 228/782 batch_loss:2.8714 batch_bpb:1.1363 running_loss:2.7745 running_bpb:1.0697 doc_len:273-274 +ttt_progress: batch 221/782 batch_loss:2.8423 batch_bpb:1.1407 running_loss:2.7747 running_bpb:1.0699 doc_len:266-267 +ttt_progress: batch 214/782 batch_loss:2.9320 batch_bpb:1.1279 running_loss:2.7752 running_bpb:1.0701 doc_len:259-260 +ttt_progress: batch 208/782 batch_loss:2.8251 batch_bpb:1.1155 running_loss:2.7753 running_bpb:1.0702 doc_len:254-254 +ttt_progress: batch 199/782 batch_loss:2.9351 batch_bpb:1.1249 running_loss:2.7758 running_bpb:1.0704 doc_len:246-247 +ttt_progress: batch 193/782 batch_loss:2.8785 batch_bpb:1.1598 running_loss:2.7761 running_bpb:1.0706 doc_len:240-241 +ttt_progress: batch 189/782 batch_loss:2.9648 batch_bpb:1.2033 running_loss:2.7766 running_bpb:1.0710 doc_len:237-237 +ttt_progress: batch 181/782 batch_loss:2.8776 batch_bpb:1.1563 running_loss:2.7768 running_bpb:1.0712 doc_len:230-230 +ttt_progress: batch 174/782 batch_loss:2.9701 batch_bpb:1.1531 running_loss:2.7773 running_bpb:1.0714 doc_len:224-224 +ttt_progress: batch 167/782 batch_loss:2.9676 batch_bpb:1.1863 running_loss:2.7778 running_bpb:1.0717 doc_len:218-218 +ttt_progress: batch 160/782 batch_loss:2.8660 batch_bpb:1.1264 running_loss:2.7780 running_bpb:1.0718 doc_len:212-212 +ttt_progress: batch 152/782 batch_loss:2.9008 batch_bpb:1.1318 running_loss:2.7783 running_bpb:1.0719 doc_len:205-206 +ttt_progress: batch 144/782 batch_loss:2.8253 batch_bpb:1.1238 running_loss:2.7784 running_bpb:1.0720 doc_len:199-200 +ttt_progress: batch 139/782 batch_loss:3.0010 batch_bpb:1.1614 running_loss:2.7789 running_bpb:1.0722 doc_len:195-195 +ttt_progress: batch 132/782 batch_loss:2.9586 batch_bpb:1.1386 running_loss:2.7793 running_bpb:1.0724 doc_len:189-189 +ttt_progress: batch 124/782 batch_loss:2.8793 batch_bpb:1.1518 running_loss:2.7795 running_bpb:1.0725 doc_len:183-184 +ttt_progress: batch 116/782 batch_loss:2.9951 batch_bpb:1.1844 running_loss:2.7799 running_bpb:1.0727 doc_len:177-178 +ttt_progress: batch 109/782 batch_loss:3.0750 batch_bpb:1.2118 running_loss:2.7805 running_bpb:1.0730 doc_len:172-173 +ttt_progress: batch 102/782 batch_loss:2.8056 batch_bpb:1.1297 running_loss:2.7805 running_bpb:1.0731 doc_len:167-168 +ttt_progress: batch 96/782 batch_loss:2.9505 batch_bpb:1.1532 running_loss:2.7808 running_bpb:1.0733 doc_len:162-163 +ttt_progress: batch 89/782 batch_loss:3.0140 batch_bpb:1.2020 running_loss:2.7812 running_bpb:1.0735 doc_len:157-158 +ttt_progress: batch 83/782 batch_loss:3.0262 batch_bpb:1.2093 running_loss:2.7817 running_bpb:1.0737 doc_len:152-153 +ttt_progress: batch 77/782 batch_loss:3.0259 batch_bpb:1.1692 running_loss:2.7821 running_bpb:1.0738 doc_len:148-148 +ttt_progress: batch 67/782 batch_loss:3.0691 batch_bpb:1.2405 running_loss:2.7825 running_bpb:1.0741 doc_len:140-141 +ttt_progress: batch 60/782 batch_loss:3.0546 batch_bpb:1.2258 running_loss:2.7829 running_bpb:1.0743 doc_len:134-135 +ttt_progress: batch 54/782 batch_loss:3.0860 batch_bpb:1.2636 running_loss:2.7833 running_bpb:1.0746 doc_len:130-130 +ttt_progress: batch 45/782 batch_loss:3.0913 batch_bpb:1.2367 running_loss:2.7838 running_bpb:1.0748 doc_len:122-123 +ttt_progress: batch 38/782 batch_loss:3.0322 batch_bpb:1.2103 running_loss:2.7841 running_bpb:1.0749 doc_len:117-118 +ttt_progress: batch 31/782 batch_loss:3.1833 batch_bpb:1.2615 running_loss:2.7846 running_bpb:1.0752 doc_len:111-112 +ttt_progress: batch 24/782 batch_loss:3.0488 batch_bpb:1.2062 running_loss:2.7849 running_bpb:1.0753 doc_len:105-106 +ttt_progress: batch 17/782 batch_loss:3.1453 batch_bpb:1.2467 running_loss:2.7852 running_bpb:1.0755 doc_len:98-99 +ttt_progress: batch 9/782 batch_loss:3.2255 batch_bpb:1.2782 running_loss:2.7857 running_bpb:1.0757 doc_len:87-89 +ttt_progress: batch 2/782 batch_loss:3.1535 batch_bpb:1.1696 running_loss:2.7860 running_bpb:1.0758 doc_len:70-75 +quantized_ttt_lora val_loss:2.77192541 val_bpb:1.07309871 eval_time:221256ms +total_eval_time:221.3s diff --git a/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed1337.log b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed1337.log new file mode 100644 index 0000000000..5390869be1 --- /dev/null +++ b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed1337.log @@ -0,0 +1,265 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /home/dex/parameter-golf-with-cc/data + datasets_dir: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_only_path: + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/PR1530_v2_wd75_c48_s1337.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: PR1530_v2_wd75_c48_s1337 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /home/dex/parameter-golf-with-cc/data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_output_dir: + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +gptq:reserving 13s, effective=587000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0095 val_bpb: 3.4877 +1/20000 train_loss: 9.0094 train_time: 0.0m tok/s: 16546769 +2/20000 train_loss: 12.2233 train_time: 0.0m tok/s: 12872122 +3/20000 train_loss: 11.2583 train_time: 0.0m tok/s: 11054269 +4/20000 train_loss: 9.6375 train_time: 0.0m tok/s: 10270787 +5/20000 train_loss: 8.2508 train_time: 0.0m tok/s: 9874358 +500/20000 train_loss: 3.2652 train_time: 0.8m tok/s: 8314872 +1000/20000 train_loss: 3.0229 train_time: 1.6m tok/s: 8286827 +1500/20000 train_loss: 3.0305 train_time: 2.4m tok/s: 8281417 +2000/20000 train_loss: 2.9830 train_time: 3.2m tok/s: 8282930 +layer_loop:enabled step:2162 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0701 train_time: 4.2m tok/s: 7790541 +3000/20000 train_loss: 2.9064 train_time: 5.4m tok/s: 7327769 +3500/20000 train_loss: 2.9745 train_time: 6.5m tok/s: 7031097 +4000/20000 train_loss: 2.9048 train_time: 7.7m tok/s: 6824180 +4000/20000 val_loss: 2.8800 val_bpb: 1.1149 +4500/20000 train_loss: 2.8605 train_time: 8.8m tok/s: 6671988 +4908/20000 val_loss: 2.7740 val_bpb: 1.0739 +stopping_early: wallclock_cap train_time: 587103ms step: 4908/20000 +peak memory allocated: 40019 MiB reserved: 44090 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.77277761 val_bpb:1.07339403 eval_time:6129ms +Serialized model: 135409136 bytes +Code size (uncompressed): 116419 bytes +Code size (compressed): 26790 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.6s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15961820 bytes +Total submission size quantized+brotli: 15988610 bytes +diagnostic quantized val_loss:2.80583112 val_bpb:1.08618966 eval_time:9186ms +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (84.6s) + +beginning TTT eval timer +ttt_lora:docs:50000 rank:96 lr:0.0001 chunk:48 +ttt_progress: batch 775/782 batch_loss:2.6934 batch_bpb:1.0663 running_loss:2.6934 running_bpb:1.0663 doc_len:5853-6355 +ttt_progress: batch 774/782 batch_loss:2.7336 batch_bpb:1.0796 running_loss:2.7128 running_bpb:1.0727 doc_len:5552-5852 +ttt_progress: batch 769/782 batch_loss:2.7750 batch_bpb:1.0982 running_loss:2.7297 running_bpb:1.0796 doc_len:4307-4479 +ttt_progress: batch 763/782 batch_loss:2.7988 batch_bpb:1.1044 running_loss:2.7422 running_bpb:1.0841 doc_len:3536-3637 +ttt_progress: batch 757/782 batch_loss:2.6461 batch_bpb:1.0226 running_loss:2.7293 running_bpb:1.0757 doc_len:3033-3108 +ttt_progress: batch 752/782 batch_loss:2.7708 batch_bpb:1.0630 running_loss:2.7338 running_bpb:1.0743 doc_len:2740-2793 +ttt_progress: batch 744/782 batch_loss:2.6603 batch_bpb:1.0598 running_loss:2.7275 running_bpb:1.0731 doc_len:2388-2419 +ttt_progress: batch 738/782 batch_loss:2.7529 batch_bpb:1.0562 running_loss:2.7293 running_bpb:1.0718 doc_len:2194-2227 +ttt_progress: batch 733/782 batch_loss:2.7652 batch_bpb:1.0552 running_loss:2.7316 running_bpb:1.0707 doc_len:2062-2090 +ttt_progress: batch 728/782 batch_loss:2.7680 batch_bpb:1.0720 running_loss:2.7337 running_bpb:1.0708 doc_len:1960-1977 +ttt_progress: batch 720/782 batch_loss:2.8259 batch_bpb:1.0794 running_loss:2.7384 running_bpb:1.0712 doc_len:1816-1832 +ttt_progress: batch 713/782 batch_loss:2.8420 batch_bpb:1.0485 running_loss:2.7430 running_bpb:1.0701 doc_len:1697-1711 +ttt_progress: batch 706/782 batch_loss:2.7227 batch_bpb:1.0466 running_loss:2.7422 running_bpb:1.0692 doc_len:1617-1627 +ttt_progress: batch 700/782 batch_loss:2.6828 batch_bpb:1.0471 running_loss:2.7399 running_bpb:1.0683 doc_len:1552-1562 +ttt_progress: batch 691/782 batch_loss:2.7060 batch_bpb:1.0449 running_loss:2.7388 running_bpb:1.0675 doc_len:1467-1476 +ttt_progress: batch 685/782 batch_loss:2.7784 batch_bpb:1.0649 running_loss:2.7401 running_bpb:1.0674 doc_len:1414-1422 +ttt_progress: batch 677/782 batch_loss:2.8688 batch_bpb:1.1121 running_loss:2.7439 running_bpb:1.0688 doc_len:1353-1360 +ttt_progress: batch 669/782 batch_loss:2.7867 batch_bpb:1.0567 running_loss:2.7451 running_bpb:1.0684 doc_len:1301-1308 +ttt_progress: batch 661/782 batch_loss:2.7312 batch_bpb:1.0240 running_loss:2.7447 running_bpb:1.0672 doc_len:1251-1258 +ttt_progress: batch 655/782 batch_loss:2.6901 batch_bpb:1.0233 running_loss:2.7434 running_bpb:1.0661 doc_len:1215-1220 +ttt_progress: batch 645/782 batch_loss:2.8010 batch_bpb:1.0959 running_loss:2.7447 running_bpb:1.0668 doc_len:1160-1166 +ttt_progress: batch 637/782 batch_loss:2.8084 batch_bpb:1.0820 running_loss:2.7461 running_bpb:1.0671 doc_len:1120-1123 +ttt_progress: batch 629/782 batch_loss:2.7338 batch_bpb:1.0474 running_loss:2.7459 running_bpb:1.0667 doc_len:1082-1086 +ttt_progress: batch 621/782 batch_loss:2.8404 batch_bpb:1.0878 running_loss:2.7477 running_bpb:1.0671 doc_len:1046-1050 +ttt_progress: batch 613/782 batch_loss:2.8250 batch_bpb:1.0633 running_loss:2.7492 running_bpb:1.0671 doc_len:1012-1016 +ttt_progress: batch 603/782 batch_loss:2.8368 batch_bpb:1.0866 running_loss:2.7507 running_bpb:1.0674 doc_len:971-974 +ttt_progress: batch 594/782 batch_loss:2.9045 batch_bpb:1.1031 running_loss:2.7532 running_bpb:1.0680 doc_len:937-940 +ttt_progress: batch 586/782 batch_loss:2.7292 batch_bpb:1.0155 running_loss:2.7529 running_bpb:1.0671 doc_len:911-914 +ttt_progress: batch 577/782 batch_loss:2.7637 batch_bpb:1.0453 running_loss:2.7530 running_bpb:1.0668 doc_len:880-884 +ttt_progress: batch 569/782 batch_loss:2.7655 batch_bpb:1.0566 running_loss:2.7532 running_bpb:1.0666 doc_len:855-858 +ttt_progress: batch 562/782 batch_loss:2.7117 batch_bpb:1.0249 running_loss:2.7526 running_bpb:1.0661 doc_len:834-837 +ttt_progress: batch 554/782 batch_loss:2.7424 batch_bpb:1.0323 running_loss:2.7525 running_bpb:1.0656 doc_len:809-812 +ttt_progress: batch 549/782 batch_loss:2.7645 batch_bpb:1.0636 running_loss:2.7526 running_bpb:1.0656 doc_len:795-798 +ttt_progress: batch 540/782 batch_loss:2.7023 batch_bpb:1.0195 running_loss:2.7520 running_bpb:1.0650 doc_len:771-774 +ttt_progress: batch 529/782 batch_loss:2.7812 batch_bpb:1.0593 running_loss:2.7524 running_bpb:1.0649 doc_len:745-747 +ttt_progress: batch 521/782 batch_loss:2.7696 batch_bpb:1.0509 running_loss:2.7526 running_bpb:1.0647 doc_len:725-727 +ttt_progress: batch 513/782 batch_loss:2.7356 batch_bpb:1.0126 running_loss:2.7524 running_bpb:1.0641 doc_len:705-707 +ttt_progress: batch 505/782 batch_loss:2.7793 batch_bpb:1.0619 running_loss:2.7527 running_bpb:1.0641 doc_len:686-688 +ttt_progress: batch 497/782 batch_loss:2.8420 batch_bpb:1.0836 running_loss:2.7536 running_bpb:1.0643 doc_len:668-671 +ttt_progress: batch 491/782 batch_loss:2.7415 batch_bpb:1.0329 running_loss:2.7535 running_bpb:1.0640 doc_len:655-657 +ttt_progress: batch 483/782 batch_loss:2.7479 batch_bpb:1.0508 running_loss:2.7534 running_bpb:1.0639 doc_len:639-641 +ttt_progress: batch 475/782 batch_loss:2.7299 batch_bpb:1.0234 running_loss:2.7532 running_bpb:1.0635 doc_len:622-623 +ttt_progress: batch 469/782 batch_loss:2.8001 batch_bpb:1.1134 running_loss:2.7536 running_bpb:1.0639 doc_len:610-611 +ttt_progress: batch 461/782 batch_loss:2.7723 batch_bpb:1.0572 running_loss:2.7538 running_bpb:1.0639 doc_len:595-597 +ttt_progress: batch 451/782 batch_loss:2.7804 batch_bpb:1.0651 running_loss:2.7540 running_bpb:1.0639 doc_len:576-579 +ttt_progress: batch 441/782 batch_loss:2.7148 batch_bpb:1.0451 running_loss:2.7537 running_bpb:1.0637 doc_len:559-560 +ttt_progress: batch 433/782 batch_loss:2.7839 batch_bpb:1.0685 running_loss:2.7539 running_bpb:1.0638 doc_len:544-545 +ttt_progress: batch 425/782 batch_loss:2.7592 batch_bpb:1.0497 running_loss:2.7539 running_bpb:1.0637 doc_len:530-532 +ttt_progress: batch 416/782 batch_loss:2.7639 batch_bpb:1.0375 running_loss:2.7540 running_bpb:1.0635 doc_len:514-516 +ttt_progress: batch 408/782 batch_loss:2.8432 batch_bpb:1.0875 running_loss:2.7546 running_bpb:1.0636 doc_len:501-503 +ttt_progress: batch 400/782 batch_loss:2.8030 batch_bpb:1.0692 running_loss:2.7550 running_bpb:1.0637 doc_len:489-490 +ttt_progress: batch 392/782 batch_loss:2.8074 batch_bpb:1.0840 running_loss:2.7553 running_bpb:1.0638 doc_len:476-478 +ttt_progress: batch 386/782 batch_loss:2.7389 batch_bpb:1.0700 running_loss:2.7552 running_bpb:1.0638 doc_len:467-468 +ttt_progress: batch 377/782 batch_loss:2.8059 batch_bpb:1.0879 running_loss:2.7555 running_bpb:1.0640 doc_len:454-455 +ttt_progress: batch 369/782 batch_loss:2.9279 batch_bpb:1.0868 running_loss:2.7565 running_bpb:1.0641 doc_len:443-444 +ttt_progress: batch 361/782 batch_loss:2.8066 batch_bpb:1.0731 running_loss:2.7568 running_bpb:1.0642 doc_len:432-433 +ttt_progress: batch 357/782 batch_loss:2.8685 batch_bpb:1.0854 running_loss:2.7575 running_bpb:1.0643 doc_len:426-427 +ttt_progress: batch 348/782 batch_loss:2.8198 batch_bpb:1.0716 running_loss:2.7578 running_bpb:1.0643 doc_len:414-415 +ttt_progress: batch 340/782 batch_loss:2.8326 batch_bpb:1.0957 running_loss:2.7582 running_bpb:1.0645 doc_len:403-404 +ttt_progress: batch 329/782 batch_loss:2.8376 batch_bpb:1.1068 running_loss:2.7586 running_bpb:1.0647 doc_len:389-390 +ttt_progress: batch 320/782 batch_loss:2.7681 batch_bpb:1.0798 running_loss:2.7586 running_bpb:1.0648 doc_len:377-378 +ttt_progress: batch 312/782 batch_loss:2.7386 batch_bpb:1.0690 running_loss:2.7585 running_bpb:1.0648 doc_len:367-368 +ttt_progress: batch 304/782 batch_loss:2.9183 batch_bpb:1.1366 running_loss:2.7593 running_bpb:1.0651 doc_len:357-358 +ttt_progress: batch 296/782 batch_loss:2.8187 batch_bpb:1.0901 running_loss:2.7595 running_bpb:1.0653 doc_len:347-348 +ttt_progress: batch 288/782 batch_loss:2.8146 batch_bpb:1.1051 running_loss:2.7598 running_bpb:1.0654 doc_len:337-339 +ttt_progress: batch 279/782 batch_loss:2.8511 batch_bpb:1.0897 running_loss:2.7602 running_bpb:1.0655 doc_len:327-329 +ttt_progress: batch 272/782 batch_loss:2.8721 batch_bpb:1.1142 running_loss:2.7606 running_bpb:1.0657 doc_len:320-321 +ttt_progress: batch 264/782 batch_loss:2.9056 batch_bpb:1.1501 running_loss:2.7612 running_bpb:1.0660 doc_len:311-312 +ttt_progress: batch 256/782 batch_loss:2.9006 batch_bpb:1.1370 running_loss:2.7617 running_bpb:1.0663 doc_len:301-302 +ttt_progress: batch 248/782 batch_loss:2.8984 batch_bpb:1.1061 running_loss:2.7622 running_bpb:1.0665 doc_len:293-294 +ttt_progress: batch 240/782 batch_loss:2.9151 batch_bpb:1.1570 running_loss:2.7628 running_bpb:1.0668 doc_len:285-286 +ttt_progress: batch 232/782 batch_loss:2.9272 batch_bpb:1.1322 running_loss:2.7633 running_bpb:1.0670 doc_len:277-278 +ttt_progress: batch 224/782 batch_loss:2.8298 batch_bpb:1.1116 running_loss:2.7635 running_bpb:1.0671 doc_len:269-270 +ttt_progress: batch 216/782 batch_loss:2.9420 batch_bpb:1.1195 running_loss:2.7641 running_bpb:1.0673 doc_len:261-262 +ttt_progress: batch 207/782 batch_loss:2.8471 batch_bpb:1.1200 running_loss:2.7644 running_bpb:1.0675 doc_len:253-254 +ttt_progress: batch 199/782 batch_loss:2.9497 batch_bpb:1.1305 running_loss:2.7649 running_bpb:1.0677 doc_len:246-247 +ttt_progress: batch 192/782 batch_loss:2.9118 batch_bpb:1.1478 running_loss:2.7654 running_bpb:1.0679 doc_len:239-240 +ttt_progress: batch 184/782 batch_loss:2.9155 batch_bpb:1.1577 running_loss:2.7658 running_bpb:1.0681 doc_len:232-233 +ttt_progress: batch 176/782 batch_loss:2.8230 batch_bpb:1.1075 running_loss:2.7659 running_bpb:1.0682 doc_len:225-226 +ttt_progress: batch 168/782 batch_loss:2.9210 batch_bpb:1.1447 running_loss:2.7664 running_bpb:1.0684 doc_len:218-219 +ttt_progress: batch 159/782 batch_loss:2.9940 batch_bpb:1.1795 running_loss:2.7669 running_bpb:1.0687 doc_len:211-212 +ttt_progress: batch 151/782 batch_loss:2.8044 batch_bpb:1.1052 running_loss:2.7670 running_bpb:1.0688 doc_len:204-205 +ttt_progress: batch 143/782 batch_loss:3.0217 batch_bpb:1.1969 running_loss:2.7676 running_bpb:1.0691 doc_len:198-199 +ttt_progress: batch 135/782 batch_loss:2.9327 batch_bpb:1.1425 running_loss:2.7680 running_bpb:1.0693 doc_len:191-192 +ttt_progress: batch 125/782 batch_loss:3.0153 batch_bpb:1.1949 running_loss:2.7686 running_bpb:1.0695 doc_len:184-185 +ttt_progress: batch 116/782 batch_loss:3.0068 batch_bpb:1.1890 running_loss:2.7691 running_bpb:1.0698 doc_len:177-178 +ttt_progress: batch 108/782 batch_loss:2.8741 batch_bpb:1.1038 running_loss:2.7693 running_bpb:1.0699 doc_len:171-172 +ttt_progress: batch 99/782 batch_loss:2.9839 batch_bpb:1.1864 running_loss:2.7697 running_bpb:1.0701 doc_len:164-165 +ttt_progress: batch 89/782 batch_loss:3.0287 batch_bpb:1.2079 running_loss:2.7702 running_bpb:1.0703 doc_len:157-158 +ttt_progress: batch 82/782 batch_loss:2.9864 batch_bpb:1.2021 running_loss:2.7706 running_bpb:1.0706 doc_len:151-152 +ttt_progress: batch 74/782 batch_loss:3.1392 batch_bpb:1.2831 running_loss:2.7712 running_bpb:1.0709 doc_len:145-146 +ttt_progress: batch 64/782 batch_loss:3.0140 batch_bpb:1.2493 running_loss:2.7716 running_bpb:1.0712 doc_len:138-139 +ttt_progress: batch 57/782 batch_loss:3.0548 batch_bpb:1.2314 running_loss:2.7720 running_bpb:1.0714 doc_len:132-133 +ttt_progress: batch 48/782 batch_loss:3.0090 batch_bpb:1.1764 running_loss:2.7724 running_bpb:1.0716 doc_len:125-126 +ttt_progress: batch 39/782 batch_loss:3.1419 batch_bpb:1.2417 running_loss:2.7729 running_bpb:1.0718 doc_len:118-119 +ttt_progress: batch 31/782 batch_loss:3.1998 batch_bpb:1.2680 running_loss:2.7734 running_bpb:1.0720 doc_len:111-112 +ttt_progress: batch 22/782 batch_loss:3.1551 batch_bpb:1.2299 running_loss:2.7739 running_bpb:1.0722 doc_len:103-104 +ttt_progress: batch 13/782 batch_loss:3.1537 batch_bpb:1.2693 running_loss:2.7743 running_bpb:1.0724 doc_len:93-94 +ttt_progress: batch 6/782 batch_loss:3.2907 batch_bpb:1.2838 running_loss:2.7748 running_bpb:1.0726 doc_len:82-84 +quantized_ttt_lora val_loss:2.77829123 val_bpb:1.07556312 eval_time:219160ms +total_eval_time:219.2s diff --git a/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed42.log b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed42.log new file mode 100644 index 0000000000..f2d449f6ad --- /dev/null +++ b/records/track_10min_16mb/2026-04-12_VarLen_TTT_Warmdown75_Chunk48/train_seed42.log @@ -0,0 +1,259 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /home/dex/parameter-golf-with-cc/data + datasets_dir: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_only_path: + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/PR1530_v2_warmdown75_chunk48_s42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: PR1530_v2_warmdown75_chunk48_s42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /home/dex/parameter-golf-with-cc/data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_output_dir: + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: /home/dex/parameter-golf-with-cc/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +gptq:reserving 13s, effective=587000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0078 val_bpb: 3.4871 +1/20000 train_loss: 9.0072 train_time: 0.0m tok/s: 16639965 +2/20000 train_loss: 12.3636 train_time: 0.0m tok/s: 12963687 +3/20000 train_loss: 11.3531 train_time: 0.0m tok/s: 11104862 +4/20000 train_loss: 9.6974 train_time: 0.0m tok/s: 10318334 +5/20000 train_loss: 8.2919 train_time: 0.0m tok/s: 9915096 +500/20000 train_loss: 3.2548 train_time: 0.8m tok/s: 8332183 +1000/20000 train_loss: 3.0138 train_time: 1.6m tok/s: 8320641 +1500/20000 train_loss: 3.0213 train_time: 2.4m tok/s: 8313830 +2000/20000 train_loss: 2.9738 train_time: 3.2m tok/s: 8310934 +layer_loop:enabled step:2169 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0599 train_time: 4.2m tok/s: 7820808 +3000/20000 train_loss: 2.8979 train_time: 5.3m tok/s: 7352679 +3500/20000 train_loss: 2.9655 train_time: 6.5m tok/s: 7052329 +4000/20000 train_loss: 2.9012 train_time: 7.7m tok/s: 6843373 +4000/20000 val_loss: 2.8758 val_bpb: 1.1133 +4500/20000 train_loss: 2.8528 train_time: 8.8m tok/s: 6688886 +4918/20000 val_loss: 2.7687 val_bpb: 1.0718 +stopping_early: wallclock_cap train_time: 587091ms step: 4918/20000 +peak memory allocated: 40019 MiB reserved: 44090 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.76748025 val_bpb:1.07134332 eval_time:5748ms +Serialized model: 135409136 bytes +Code size (uncompressed): 116419 bytes +Code size (compressed): 26790 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.6s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15967356 bytes +Total submission size quantized+brotli: 15994146 bytes +diagnostic quantized val_loss:2.80017166 val_bpb:1.08399877 eval_time:8823ms +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (84.7s) + +beginning TTT eval timer +ttt_lora:docs:50000 rank:96 lr:0.0001 chunk:48 +ttt_progress: batch 781/782 batch_loss:2.5641 batch_bpb:1.0585 running_loss:2.5641 running_bpb:1.0585 doc_len:14510-25988 +ttt_progress: batch 722/782 batch_loss:2.7704 batch_bpb:1.0593 running_loss:2.5826 running_bpb:1.0586 doc_len:1846-1861 +ttt_progress: batch 715/782 batch_loss:2.6478 batch_bpb:1.0404 running_loss:2.5877 running_bpb:1.0571 doc_len:1725-1739 +ttt_progress: batch 708/782 batch_loss:2.7217 batch_bpb:1.0459 running_loss:2.5968 running_bpb:1.0563 doc_len:1639-1649 +ttt_progress: batch 701/782 batch_loss:2.7526 batch_bpb:1.0468 running_loss:2.6064 running_bpb:1.0557 doc_len:1562-1572 +ttt_progress: batch 694/782 batch_loss:2.7641 batch_bpb:1.0668 running_loss:2.6151 running_bpb:1.0563 doc_len:1494-1504 +ttt_progress: batch 687/782 batch_loss:2.7127 batch_bpb:1.0476 running_loss:2.6200 running_bpb:1.0559 doc_len:1432-1441 +ttt_progress: batch 681/782 batch_loss:2.8193 batch_bpb:1.0704 running_loss:2.6293 running_bpb:1.0566 doc_len:1383-1393 +ttt_progress: batch 678/782 batch_loss:2.7853 batch_bpb:1.0484 running_loss:2.6361 running_bpb:1.0562 doc_len:1361-1368 +ttt_progress: batch 669/782 batch_loss:2.7812 batch_bpb:1.0546 running_loss:2.6419 running_bpb:1.0561 doc_len:1301-1308 +ttt_progress: batch 659/782 batch_loss:2.7184 batch_bpb:1.0236 running_loss:2.6447 running_bpb:1.0549 doc_len:1239-1245 +ttt_progress: batch 658/782 batch_loss:2.8122 batch_bpb:1.0764 running_loss:2.6506 running_bpb:1.0557 doc_len:1234-1239 +ttt_progress: batch 651/782 batch_loss:2.7194 batch_bpb:1.0445 running_loss:2.6529 running_bpb:1.0553 doc_len:1193-1198 +ttt_progress: batch 643/782 batch_loss:2.7971 batch_bpb:1.0664 running_loss:2.6573 running_bpb:1.0556 doc_len:1150-1155 +ttt_progress: batch 635/782 batch_loss:2.7386 batch_bpb:1.0601 running_loss:2.6597 running_bpb:1.0558 doc_len:1111-1116 +ttt_progress: batch 627/782 batch_loss:2.7353 batch_bpb:1.0354 running_loss:2.6617 running_bpb:1.0552 doc_len:1073-1077 +ttt_progress: batch 618/782 batch_loss:2.7375 batch_bpb:1.0495 running_loss:2.6636 running_bpb:1.0550 doc_len:1031-1037 +ttt_progress: batch 610/782 batch_loss:2.8365 batch_bpb:1.0649 running_loss:2.6678 running_bpb:1.0553 doc_len:999-1004 +ttt_progress: batch 603/782 batch_loss:2.8364 batch_bpb:1.0865 running_loss:2.6716 running_bpb:1.0560 doc_len:971-974 +ttt_progress: batch 595/782 batch_loss:2.7325 batch_bpb:1.0565 running_loss:2.6730 running_bpb:1.0560 doc_len:940-943 +ttt_progress: batch 579/782 batch_loss:2.6389 batch_bpb:1.0057 running_loss:2.6723 running_bpb:1.0550 doc_len:888-891 +ttt_progress: batch 571/782 batch_loss:2.7142 batch_bpb:1.0354 running_loss:2.6731 running_bpb:1.0546 doc_len:862-865 +ttt_progress: batch 563/782 batch_loss:2.8009 batch_bpb:1.0625 running_loss:2.6754 running_bpb:1.0548 doc_len:837-840 +ttt_progress: batch 555/782 batch_loss:2.7677 batch_bpb:1.0562 running_loss:2.6770 running_bpb:1.0548 doc_len:812-815 +ttt_progress: batch 547/782 batch_loss:2.7330 batch_bpb:1.0322 running_loss:2.6779 running_bpb:1.0544 doc_len:790-793 +ttt_progress: batch 539/782 batch_loss:2.7275 batch_bpb:1.0443 running_loss:2.6787 running_bpb:1.0542 doc_len:769-771 +ttt_progress: batch 531/782 batch_loss:2.7752 batch_bpb:1.0526 running_loss:2.6802 running_bpb:1.0542 doc_len:750-752 +ttt_progress: batch 523/782 batch_loss:2.8170 batch_bpb:1.0579 running_loss:2.6822 running_bpb:1.0543 doc_len:730-732 +ttt_progress: batch 515/782 batch_loss:2.7854 batch_bpb:1.0736 running_loss:2.6836 running_bpb:1.0545 doc_len:710-713 +ttt_progress: batch 507/782 batch_loss:2.7586 batch_bpb:1.0414 running_loss:2.6846 running_bpb:1.0544 doc_len:690-693 +ttt_progress: batch 499/782 batch_loss:2.7829 batch_bpb:1.0502 running_loss:2.6859 running_bpb:1.0543 doc_len:673-675 +ttt_progress: batch 491/782 batch_loss:2.7391 batch_bpb:1.0320 running_loss:2.6866 running_bpb:1.0540 doc_len:655-657 +ttt_progress: batch 483/782 batch_loss:2.7442 batch_bpb:1.0494 running_loss:2.6873 running_bpb:1.0540 doc_len:639-641 +ttt_progress: batch 475/782 batch_loss:2.7312 batch_bpb:1.0240 running_loss:2.6878 running_bpb:1.0536 doc_len:622-623 +ttt_progress: batch 472/782 batch_loss:2.8009 batch_bpb:1.0707 running_loss:2.6890 running_bpb:1.0538 doc_len:616-618 +ttt_progress: batch 464/782 batch_loss:2.7228 batch_bpb:1.0790 running_loss:2.6894 running_bpb:1.0541 doc_len:600-602 +ttt_progress: batch 456/782 batch_loss:2.8168 batch_bpb:1.0698 running_loss:2.6908 running_bpb:1.0542 doc_len:586-587 +ttt_progress: batch 447/782 batch_loss:2.8270 batch_bpb:1.0871 running_loss:2.6921 running_bpb:1.0546 doc_len:569-571 +ttt_progress: batch 440/782 batch_loss:2.8676 batch_bpb:1.0948 running_loss:2.6938 running_bpb:1.0550 doc_len:556-559 +ttt_progress: batch 432/782 batch_loss:2.7656 batch_bpb:1.0522 running_loss:2.6945 running_bpb:1.0549 doc_len:542-544 +ttt_progress: batch 423/782 batch_loss:2.7391 batch_bpb:1.0288 running_loss:2.6949 running_bpb:1.0547 doc_len:526-528 +ttt_progress: batch 417/782 batch_loss:2.8110 batch_bpb:1.0541 running_loss:2.6960 running_bpb:1.0547 doc_len:516-517 +ttt_progress: batch 409/782 batch_loss:2.7069 batch_bpb:1.0458 running_loss:2.6960 running_bpb:1.0546 doc_len:503-505 +ttt_progress: batch 400/782 batch_loss:2.7949 batch_bpb:1.0661 running_loss:2.6969 running_bpb:1.0547 doc_len:489-490 +ttt_progress: batch 392/782 batch_loss:2.8008 batch_bpb:1.0815 running_loss:2.6977 running_bpb:1.0549 doc_len:476-478 +ttt_progress: batch 384/782 batch_loss:2.8444 batch_bpb:1.0913 running_loss:2.6988 running_bpb:1.0552 doc_len:464-466 +ttt_progress: batch 376/782 batch_loss:2.7214 batch_bpb:1.0451 running_loss:2.6990 running_bpb:1.0551 doc_len:453-454 +ttt_progress: batch 368/782 batch_loss:2.8527 batch_bpb:1.0884 running_loss:2.7001 running_bpb:1.0554 doc_len:441-443 +ttt_progress: batch 360/782 batch_loss:2.8345 batch_bpb:1.0809 running_loss:2.7010 running_bpb:1.0556 doc_len:430-432 +ttt_progress: batch 352/782 batch_loss:2.7573 batch_bpb:1.0961 running_loss:2.7014 running_bpb:1.0558 doc_len:419-420 +ttt_progress: batch 344/782 batch_loss:2.8870 batch_bpb:1.1066 running_loss:2.7026 running_bpb:1.0562 doc_len:408-410 +ttt_progress: batch 336/782 batch_loss:2.9574 batch_bpb:1.1687 running_loss:2.7042 running_bpb:1.0569 doc_len:398-399 +ttt_progress: batch 328/782 batch_loss:2.7912 batch_bpb:1.0823 running_loss:2.7048 running_bpb:1.0570 doc_len:388-389 +ttt_progress: batch 318/782 batch_loss:2.8185 batch_bpb:1.0691 running_loss:2.7054 running_bpb:1.0571 doc_len:374-376 +ttt_progress: batch 310/782 batch_loss:2.7991 batch_bpb:1.0844 running_loss:2.7060 running_bpb:1.0573 doc_len:364-365 +ttt_progress: batch 302/782 batch_loss:2.8318 batch_bpb:1.0983 running_loss:2.7066 running_bpb:1.0575 doc_len:354-355 +ttt_progress: batch 293/782 batch_loss:2.7622 batch_bpb:1.0671 running_loss:2.7069 running_bpb:1.0575 doc_len:343-345 +ttt_progress: batch 286/782 batch_loss:2.8915 batch_bpb:1.0984 running_loss:2.7079 running_bpb:1.0578 doc_len:335-336 +ttt_progress: batch 279/782 batch_loss:2.8551 batch_bpb:1.0913 running_loss:2.7086 running_bpb:1.0579 doc_len:327-329 +ttt_progress: batch 272/782 batch_loss:2.8649 batch_bpb:1.1114 running_loss:2.7094 running_bpb:1.0582 doc_len:320-321 +ttt_progress: batch 264/782 batch_loss:2.8967 batch_bpb:1.1465 running_loss:2.7103 running_bpb:1.0586 doc_len:311-312 +ttt_progress: batch 256/782 batch_loss:2.8935 batch_bpb:1.1343 running_loss:2.7111 running_bpb:1.0589 doc_len:301-302 +ttt_progress: batch 248/782 batch_loss:2.8908 batch_bpb:1.1032 running_loss:2.7119 running_bpb:1.0591 doc_len:293-294 +ttt_progress: batch 240/782 batch_loss:2.9097 batch_bpb:1.1548 running_loss:2.7127 running_bpb:1.0595 doc_len:285-286 +ttt_progress: batch 232/782 batch_loss:2.9304 batch_bpb:1.1334 running_loss:2.7136 running_bpb:1.0598 doc_len:277-278 +ttt_progress: batch 224/782 batch_loss:2.8305 batch_bpb:1.1119 running_loss:2.7141 running_bpb:1.0601 doc_len:269-270 +ttt_progress: batch 216/782 batch_loss:2.9350 batch_bpb:1.1169 running_loss:2.7149 running_bpb:1.0603 doc_len:261-262 +ttt_progress: batch 207/782 batch_loss:2.8403 batch_bpb:1.1173 running_loss:2.7154 running_bpb:1.0605 doc_len:253-254 +ttt_progress: batch 199/782 batch_loss:2.9506 batch_bpb:1.1309 running_loss:2.7162 running_bpb:1.0607 doc_len:246-247 +ttt_progress: batch 191/782 batch_loss:2.9518 batch_bpb:1.1527 running_loss:2.7171 running_bpb:1.0611 doc_len:238-239 +ttt_progress: batch 183/782 batch_loss:2.8727 batch_bpb:1.1467 running_loss:2.7176 running_bpb:1.0613 doc_len:231-232 +ttt_progress: batch 175/782 batch_loss:2.8501 batch_bpb:1.1172 running_loss:2.7180 running_bpb:1.0615 doc_len:225-225 +ttt_progress: batch 166/782 batch_loss:2.9669 batch_bpb:1.1438 running_loss:2.7188 running_bpb:1.0618 doc_len:217-218 +ttt_progress: batch 160/782 batch_loss:2.8741 batch_bpb:1.1296 running_loss:2.7193 running_bpb:1.0620 doc_len:212-212 +ttt_progress: batch 151/782 batch_loss:2.7970 batch_bpb:1.1023 running_loss:2.7195 running_bpb:1.0621 doc_len:204-205 +ttt_progress: batch 145/782 batch_loss:2.9028 batch_bpb:1.1389 running_loss:2.7200 running_bpb:1.0623 doc_len:200-200 +ttt_progress: batch 137/782 batch_loss:2.9398 batch_bpb:1.1846 running_loss:2.7206 running_bpb:1.0627 doc_len:193-194 +ttt_progress: batch 126/782 batch_loss:2.9348 batch_bpb:1.1924 running_loss:2.7212 running_bpb:1.0630 doc_len:185-185 +ttt_progress: batch 116/782 batch_loss:2.9994 batch_bpb:1.1861 running_loss:2.7219 running_bpb:1.0633 doc_len:177-178 +ttt_progress: batch 110/782 batch_loss:3.0235 batch_bpb:1.1739 running_loss:2.7226 running_bpb:1.0636 doc_len:173-173 +ttt_progress: batch 103/782 batch_loss:2.9007 batch_bpb:1.1227 running_loss:2.7230 running_bpb:1.0637 doc_len:168-168 +ttt_progress: batch 94/782 batch_loss:2.9854 batch_bpb:1.1774 running_loss:2.7236 running_bpb:1.0640 doc_len:160-161 +ttt_progress: batch 86/782 batch_loss:3.0413 batch_bpb:1.2658 running_loss:2.7243 running_bpb:1.0644 doc_len:154-155 +ttt_progress: batch 78/782 batch_loss:2.9122 batch_bpb:1.1299 running_loss:2.7247 running_bpb:1.0645 doc_len:148-149 +ttt_progress: batch 70/782 batch_loss:3.0649 batch_bpb:1.1647 running_loss:2.7254 running_bpb:1.0647 doc_len:142-143 +ttt_progress: batch 62/782 batch_loss:2.9988 batch_bpb:1.2130 running_loss:2.7259 running_bpb:1.0650 doc_len:136-137 +ttt_progress: batch 53/782 batch_loss:3.1212 batch_bpb:1.2305 running_loss:2.7266 running_bpb:1.0653 doc_len:129-130 +ttt_progress: batch 47/782 batch_loss:2.9503 batch_bpb:1.1790 running_loss:2.7270 running_bpb:1.0655 doc_len:124-125 +ttt_progress: batch 40/782 batch_loss:2.9955 batch_bpb:1.2049 running_loss:2.7274 running_bpb:1.0657 doc_len:119-119 +ttt_progress: batch 32/782 batch_loss:3.0262 batch_bpb:1.2097 running_loss:2.7279 running_bpb:1.0659 doc_len:112-113 +ttt_progress: batch 24/782 batch_loss:3.0660 batch_bpb:1.2130 running_loss:2.7284 running_bpb:1.0661 doc_len:105-106 +ttt_progress: batch 17/782 batch_loss:3.1504 batch_bpb:1.2488 running_loss:2.7290 running_bpb:1.0664 doc_len:98-99 +ttt_progress: batch 7/782 batch_loss:3.2082 batch_bpb:1.2310 running_loss:2.7295 running_bpb:1.0666 doc_len:84-86 +quantized_ttt_lora val_loss:2.77300534 val_bpb:1.07351678 eval_time:213423ms +total_eval_time:213.4s