diff --git a/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/README.md b/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/README.md new file mode 100644 index 0000000000..fcebe64697 --- /dev/null +++ b/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/README.md @@ -0,0 +1,112 @@ +# Non-Record: PR #1901 base + LQER Asymmetric + Brotli/Byte-Shuffle Compression + +**Status: non-record discussion** — patched implementation provided; full 3-seed validation could not be completed within available compute budget. + +This submission proposes two orthogonal additions to the **PR #1901 stack from @Karen042009** (val_bpb 0.83353, pending merge): + +1. **LQER asymmetric rank-4 post-quantization correction** — reduces the int6 Sigma-Delta quantization tax by storing top-K=2 weight residuals as INT2/INT4 SVD factors. +2. **Brotli-11 + byte-shuffle compression** replacing LZMA — recovers ~150–280 KB of artifact budget that the model can re-spend. + +Both contributions are stack-orthogonal and architecturally minimal (~80 LoC added to PR #1901's 1,221-line pipeline). The patched `train_gpt.py` is provided LZMA-wrapped at 18,204 bytes (vs PR #1901's 53,443 bytes raw — a 65.9% code-byte saving alone). + +## Why non-record + +Available compute was a single $25 starter grant + remaining personal balance. The $500 development grant submitted on 2026-04-27 did not return a decision before the 2026-04-30 deadline. The two compute attempts before submission: + +- **2026-04-26 8×H100 SXM bid run** (different stack, PR #1493+LQER): preempted at training step ~4,000 of ~6,700, ~4–5 minutes before the artifact would have been emitted. Sidecar log uploaded to HF preserves train_loss 2.91 at step 4,000. +- **2026-04-29 8×H100 SXM bid run** (this stack): preempted during HuggingFace data prefetch (50% / 250M tokens per rank), well before training started. + +Both pods were single-seed attempts on bid pricing because on-demand 8×H100 SXM either repeatedly stuck in container boot (machine `qd6276xi9ky5`) or exceeded the available balance. Without the development grant covering 3-seed validation, this submission is filed as a non-record contribution: implementation + theoretical δ-BPB estimate, no measured val_bpb. + +## Theoretical contribution analysis + +### LQER asymmetric rank-4 on Sigma-Delta residuals + +PR #1901 uses Dynamic MSE Sigma-Delta (SDClip σ-grid {2.5, 3.0, 3.5, 4.0}) with INT6 codes + per-row fp16 scale. After their `export_submission` quantization loop, this submission inserts: + +```python +for name, codes in quantised_state.items(): + if not codes.dim() >= 2: continue + W_q = (codes.float() * scale) + W_fp = net.state_dict()[name].float().cpu() + E = W_fp - W_q + cands.append((name, E, ||E||_F)) +cands.sort(key=lambda x: -x[2]) +for name, E, _ in cands[:top_k=2]: + U, S, Vh = svd(E, full_matrices=False) + A = (U[:, :rank=4] * S[:4]).contiguous() + B = Vh[:4, :].contiguous() + qA, sA, qB, sB = lqer_pack_asym(A, B, group=64) + quantised_state[name + '_lqA'] = qA # INT2 + quantised_state[name + '_lqB'] = qB # INT4 +``` + +At dequantization, `W_corrected = W_dequant + A_dequant @ B_dequant`. + +LQER paper (Lee et al. 2023, arXiv:2310.18313) reports 0.5–1.5 bit-per-weight equivalent reduction. PR #1797 (@dexhunter) validated the asymmetric variant on Hessian-GPTQ and observed −0.009 BPB recovery (1.06157 base on PR #1787). Sigma-Delta error diffusion already auto-compensates within-row error; we expect the LQER recovery on top of Sigma-Delta to be **smaller, in the range −0.002 to −0.005 BPB**, because the residual variance is reduced before LQER sees it. + +This is, to our knowledge, the first proposed application of LQER to a Sigma-Delta-quantized stack in this competition. + +### Brotli-11 + stride-2 byte-shuffle + +PR #1901 uses `lzma.FORMAT_XZ preset=9 dict_size=128MB`. We replace this with stride-2 byte-shuffle (groups MSB/LSB bytes via position-mod-stride permutation) followed by Brotli quality=11 generic mode. PR #1855 reports ~150–280 KB savings on int6 weight blobs from a comparable per-group lrzip + brotli pipeline. Saved bytes are re-invested in a slightly larger model (PR #1901 already auto-downsizes hidden_size to fit; with Brotli savings, hidden_size could rise from 336 to 344 or higher). + +Expected δ-BPB from larger model + same training: **−0.002 to −0.005 BPB** based on the empirical hidden_size→BPB curve in PR #1901's `[SizeCheck]` log (~0.005 BPB per +16 hidden dimensions on the same training stack). + +### Combined estimate + +Stacked: −0.005 to −0.010 BPB on top of PR #1901's 0.83353 → projected **0.823–0.829 BPB** on a 3-seed run. This is below the current pending top-2 (PR #1901 0.83353, #1848 Mikey 0.87980 — though #1848 is unverified) and above the projected #1818 LQER+SP1024 (1.06108 on a different/weaker base). + +If validated, this would be the lowest val_bpb in a non-PPM submission (PPM-based PRs face the @sharpobject argument in Issue #1872 / PR #1905 about probability-distribution validity). + +## What is in this submission + +| File | Purpose | +|---|---| +| `train_gpt.py` | LZMA-wrapped patched code (18,204 bytes; raw 53,586 bytes) | +| `train_gpt_unwrapped.py` | Raw patched source for review | +| `submission.json` | Metadata; val_bpb fields are empty pending validation | +| `partial_run_2026-04-29.log` | HuggingFace data-prefetch log up to preemption point | +| `partial_run_2026-04-26.log` | Earlier 4,000-step training log (different stack, train_loss=2.91 at step 4000) | +| `README.md` | This file | + +## Test plan (incomplete — see compute notes above) + +- [x] Patch applies cleanly to PR #1901 (syntax check, function-level replacement validated locally) +- [x] LZMA-base85 wrapper round-trips correctly (verified via `compile()` + decompress identity check) +- [x] Patched code launches on 8×H100 SXM, reaches HF data-prefetch phase (verified by `partial_run_2026-04-29.log`) +- [ ] **Pending**: 3-seed val_bpb measurement on 8×H100 SXM with full 600s training cap +- [ ] **Pending**: artifact size verification under the 16 MB cap +- [ ] **Pending**: ablation `LQER_TOP_K ∈ {1, 2, 3}`, `LQER_RANK ∈ {2, 4, 8}` +- [ ] **Pending**: Brotli vs LZMA artifact size A/B on identical model + +If this submission is approved as a record-eligible record after validation, I commit to providing 3-seed logs from a future compute window. + +## Attribution + +- **Base stack PR #1901**: @Karen042009 — DualTokenHashSkip, LayerScale Recurrence, SharedMoE, AdaMuon optimizer, Dynamic MSE SDClip, Score-First TTT +- **LQER asymmetric variant**: @dexhunter (PR #1797) — first competition implementation +- **LQER paper**: Lee et al. 2023 (arXiv:2310.18313) +- **Brotli + byte-shuffle compression idea**: @dexhunter (PR #1855) + +## Compliance notes (verifiable from code) + +- INT6 SDClip + INT2/INT4 LQER factors + fp16 scales — all standard tensor types +- Brotli-11 generic mode — public format +- No PPM mixture (avoids the Issue #1872 probability-distribution dispute) +- Score-First TTT inherited verbatim from PR #1901 +- Training stays within 600s wallclock cap on 8×H100 (PR #1901's auto-configured schedule unchanged) +- Eval stays within 600s budget (PR #1901's eval pipeline; LQER dequant adds <1s for top-K=2) + +## Reproduction + +The patched `train_gpt.py` is self-contained and uses HuggingFace streaming for FineWeb data. To reproduce: + +```bash +pip install brotli transformers tokenizers datasets huggingface_hub torch +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +Tokenizer is trained at first run from `HuggingFaceFW/fineweb` sample-10BT (~25 min CPU on 8 vCPU). Set `HF_TOKEN` env var for higher-rate-limit downloads. + +A pre-trained tokenizer for vocab=8192 (compatible with this stack at hidden=336/layers=12) is available at `https://huggingface.co/datasets/squ11z1/pgolf-lqer/blob/main/moe/pg_tokenizer_v10_2/tokenizer.json` to skip the tokenizer-training step. diff --git a/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/partial_run_2026-04-29.log b/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/partial_run_2026-04-29.log new file mode 100644 index 0000000000..9229163b59 --- /dev/null +++ b/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/partial_run_2026-04-29.log @@ -0,0 +1,47 @@ +W0429 11:37:05.558000 226 torch/distributed/run.py:803] +W0429 11:37:05.558000 226 torch/distributed/run.py:803] ***************************************** +W0429 11:37:05.558000 226 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0429 11:37:05.558000 226 torch/distributed/run.py:803] ***************************************** +[Rank 4] CUDA_VISIBLE_DEVICES=4 (cuda:0 → physical GPU 4) | world_size=8 +[Rank 2] CUDA_VISIBLE_DEVICES=2 (cuda:0 → physical GPU 2) | world_size=8 +[Rank 5] CUDA_VISIBLE_DEVICES=5 (cuda:0 → physical GPU 5) | world_size=8 +[Rank 7] CUDA_VISIBLE_DEVICES=7 (cuda:0 → physical GPU 7) | world_size=8 +[Rank 3] CUDA_VISIBLE_DEVICES=3 (cuda:0 → physical GPU 3) | world_size=8 +[Rank 6] CUDA_VISIBLE_DEVICES=6 (cuda:0 → physical GPU 6) | world_size=8 +[Rank 1] CUDA_VISIBLE_DEVICES=1 (cuda:0 → physical GPU 1) | world_size=8 +[Rank 0] CUDA_VISIBLE_DEVICES=0 (cuda:0 → physical GPU 0) | world_size=8 +[Logger] Writing to: run_20260429T113714Z.record +[SizeCheck] hidden=512 moe=1024 → est LZMA≈34.31 MB +[SizeCheck] hidden=480 moe=960 → est LZMA≈30.33 MB +[SizeCheck] hidden=448 moe=896 → est LZMA≈26.59 MB +[SizeCheck] hidden=416 moe=832 → est LZMA≈23.10 MB +[SizeCheck] hidden=384 moe=768 → est LZMA≈19.85 MB +[rank4]:[W429 11:37:16.338264719 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +[SizeCheck] hidden=352 moe=704 → est LZMA≈16.85 MB +[SizeCheck] hidden=344 moe=688 → est LZMA≈16.14 MB +[rank5]:[W429 11:37:16.705048389 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +[SizeCheck] hidden=336 moe=672 → est LZMA≈15.44 MB +================================================================= + DEEP MACARON — Competition Run — H100 8× + Macaron layer structure (FFN -> Attn -> FFN) + MQA Attention, SwiGLU, SiLU + WSD (Warmup-Stable-Decay) Learning Rate Schedule + vocab=8192 hidden=336 layers=12 +================================================================= +Loading tokeniser from disk … +[rank1]:[W429 11:37:16.848211213 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +Computing top-K bias tokens (5000 docs) … +[rank3]:[W429 11:37:17.924075537 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +[rank2]:[W429 11:37:17.041460460 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +[rank7]:[W429 11:37:17.137768734 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +[rank6]:[W429 11:37:17.223251910 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +[rank0]:[W429 11:37:31.042983705 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +Tokeniser ready in 20.95s +[Data] Prefetching 250,000,000 tokens / rank … + [Data] 25.0M / 250.0M tok (10%) — 0.6M tok/s + [Data] 50.0M / 250.0M tok (20%) — 0.6M tok/s + [Data] 75.0M / 250.0M tok (30%) — 0.6M tok/s + [Data] 100.0M / 250.0M tok (40%) — 0.6M tok/s + [Data] 125.0M / 250.0M tok (50%) — 0.6M tok/s diff --git a/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/submission.json b/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/submission.json new file mode 100644 index 0000000000..35285bbdf7 --- /dev/null +++ b/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/submission.json @@ -0,0 +1,47 @@ +{ + "author": "squ11z1", + "github_id": "squ11z1", + "name": "Non-Record: PR #1901 base + LQER Asymmetric + Brotli/Byte-Shuffle Compression", + "date": "2026-04-29", + "track": "10min_16mb", + "submission_type": "non-record", + "val_bpb": null, + "val_bpb_std": null, + "seeds": [], + "seed_results": {}, + "hardware_intended": "8xH100 80GB SXM", + "hardware_actual": "no validation completed (compute constraint)", + "pytorch_version": "2.9.1+cu128", + "technique_summary": "PR #1901 base (DualHash + LayerScale Recurrence + SharedMoE + AdaMuon + Dynamic SDClip + Score-First TTT, 0.83353 BPB pending) + LQER asymmetric rank-4 top-K=2 post-quantization correction (INT2 A / INT4 B per-group-64 packing) + Brotli-11 with stride-2 byte-shuffle replacing LZMA", + "estimated_delta_bpb": { + "lqer_asymmetric": -0.003, + "brotli_byte_shuffle": -0.003, + "combined_estimate": -0.005, + "projected_val_bpb": 0.829, + "estimate_basis": "PR #1797 LQER asym -0.009 BPB on Hessian-GPTQ scaled down for Sigma-Delta residuals; PR #1855 ~280KB compression saving translated via PR #1901 hidden->BPB curve" + }, + "compliance_intended": { + "train_under_600s": true, + "artifact_under_16mb": true, + "eval_under_600s": true, + "no_slot": true, + "no_pre_quant_ttt": true, + "no_etlb": true, + "no_ngram_cache": true, + "no_ppm_mixture": true, + "score_first_ttt": true, + "three_seeds": false + }, + "attribution": { + "base_stack_pr1901": "@Karen042009", + "lqer_asymmetric": "@dexhunter (PR #1797)", + "brotli_byte_shuffle_idea": "@dexhunter (PR #1855)", + "lqer_paper": "arXiv:2310.18313 (Lee et al. 2023)" + }, + "code_artifacts": { + "patched_train_gpt_lzma_wrapped_bytes": 18204, + "patched_train_gpt_raw_bytes": 53586, + "code_byte_saving_vs_pr1901_raw": 35239 + }, + "compute_status": "non-record submission filed pending compute access for 3-seed validation; partial 50%-data-prefetch log available in HF dataset squ11z1/pgolf-lqer" +} diff --git a/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/train_gpt.py b/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/train_gpt.py new file mode 100644 index 0000000000..9a97d87ed9 --- /dev/null +++ b/records/track_10min_16mb/2026-04-29_squ11z1_PR1901_LQER_Brotli_NonRecord/train_gpt.py @@ -0,0 +1,234 @@ +import lzma,base64 +_b=b""" +{Wp48S^xk9=GL@E0stWa8~^|S5YJf5;N?FzkzD{0Xa0$3xunI&Pw#2T6kJ)O>xMOqKrGCZeHt@#zgS +~@!hMY~fTa6?qd3j1nVx(|de;5oT1KG^zQcIotdQ5!<&lx6tyAJD(OP9^=Iw#aV9bmaJo +PCJ&v;FJzSNXWIXBH4c6b_0%yc)v@&k3%evmJa}ttyK)_ +N8S1}oh8PwOh`_vIzVIDi=C$g@2X&%++@vqBXU# +6ET?&SC&;kLu;W7zwcpdQc3Jtc~$<6kDXY#H6HlSt0HsiOV@RcD9Cs{CUD%!|Yl0uy`>Na`K%80T}* +DH^|h!zeeBT8GF0^MQg4nO7|pbi~(C?jr8tJ)j3R2c4r!Bw*Ck7Q$=>oAiQLN&Q|!SGbeascv8-T6Xl3p=xz^6%O21ux@?f=5xFx_sI8=i{_=x`1bbwcw2s}St +w&?#$Wgz%KO%hVW{IZmFmK0Y5CH}r2Y}M9e@=UhwJD7)~k=JQ-U}3$A2(#hF3@H$usIk~WH+CapP;| +ZiKaFDLDxtB8Wb8nOCXifq?#w*x$87xc+CaOxaevsYh*TXCG7l{?+cx59^Ut!-yJ8@Gs${ywmD7+2x +Ru(bCVXSTj1P8ly4wcSxS$rmezvdq)07w*o +aO!m+#P!UQe=nD7V4SBIGI~EU6dbjBn+Tpu~wxn1>JA(RJ3N7Vo8BiDg7>`Qu3DytNOjB3R5P +wOjjOt>r}Dnr;ZY+<1n|FYd6?s$RPSzRLm8$rB%ll?lhMJ$$ylG$&(VT`y;?EpxJj?Q)40(O)COUvt +6333#if>yabMg^HD#;B=2SCM8{NcV+jx=TB>ykv;IT4t=<=HX#e+(iL2h)+mm<=Asa$A5NiT+;;cVw$Rcl!<%`uYWK5X`7ll*kq$qk^9haqq}| +Yynf>fF42xKJQuOswcD-I!-<18t7*L`#6g8b)M`a@+W2SF@6wZGd{#xX-e7_hSsSUI7|4VaI +XWyXKzG10vVCLB-V&~6A5hcmDtkFyXqS$@0N({dreOY2yHnwnM7{pfL&#O*eb1!16z%JKbm+;!o^sF +m947^du9`5tHNoP?>%2FV2yysSUnF@vX@KL11IIFSgrNf%^v)N%)s#xP96Ul{VoQi+wiy!h|_KsctV +Sy^f81I91yf-YhI{7gL@x_rEpz_qoNWBV1xS=Df3z+va&uw$Pep6$^dq9f_zXHeB41V;hU63V1TrS6 +v1|N+l3?&y8cb0fhQd}I;aMpz717AMeS#VIY^PgeBsPN(9JdbB>AS|ENKxO^iyc +HMjdz|78ZI!{GtpqZJ4zjCp9`6e%kDjPd7d442v5W7Zt~kfVH9Dn5!`MZhZBg+lb89x0gz@AKTY^IY +Un8&Zf0`8`I+^)yqC1P_D;KmWK*D`L8FPOAj?MVeCv`kFZj|D!6i>s|Z!#)pBb!fysqA=nxd-BV;nN +s;%8-$xd^DC-_+i*4{e1bIA5P%5{pYeCB+Pr!jAR}i})$cxSnsB2`d$M6B^uv*>gs5nC@baOm{AwR` +1n_u!|H)@hpI2l$nv9h+r$nb?=_np&sI2N;;zS*n-tQQAVvbplVaqaq{f)`HU^-q1jkeZ=qniTba_+gjjIeGd(YF?#g1jK4M|bTs-=62xI3P={5w9iKv+_0q(@E80-%iMC#*6Hx&H)fHgycp24OrRY8d@hgrLc*(v^4OGTnT^428-l&7qB<#w@@S?DEH$oM^kCr@;odgx2#^!cEl1w!w)HVnYuQ8?qMGVsOyLIo)b +Y{Z(x1zMk&b`MHl>~G->lZnbimiNAUr>G5Ha1?YASMX63*3fT8W0d#>yw_!O_igpb5C@z&hv&kJ6oF +kZ=czLcc|lC8b(c1bt-t6JLQ!`A4r{(JDVeH+u0Bw95d@eBhhyAAHOxl1;r^5O@Pw-jgPWRY`>p^tZ +eaev_D_YR3wqQajYHVURl7MIX>ZtMkK60{@?*2NiREI}tRk63=h*wm2)#gnuOE|zLm+aBe~U-sP2=U98ffvBXSNAL?y)f2z=;}1;8x)jP +ZA2>;cO~gbIjH`Ql8uih8dADR9l(BfTH}=fV$LxIdR$9Gm)YbhXD~MCmsT@0$ohNXcjuDz*@ufANuA +L-k2RuRm@b`t<>V{j7m?F8c-Jb$0b4y>8G;wnJ_T9n#!KuSdwFAvkrsN-mWPx`O{S{%~z-eCDD4%iU +1Yxv#b~(?~P|IZN=rCgIkB{7xODg_I{)UJp}T+TI_h@f??n4<%rdW|QVVi9CRQ=JyiAKm6o2TVmPJ? +9bP*G|kSHOmw+LUEg9+o>535+Ef;^Rhq)XK^~J7$osCP13%M<_4R$Z6?$E99427G<1B0@NwB2qMt{*){}dYh$5c-=f$-x|6 +UZEA$pQj<#hnRjdQ#PPGm!5R-Hs7@zUI$xhm$=0t>i)=y^eM88>?zD7e}P!Ic>PfPUM}P;IKuASk;* +Vg;KlG3<~V5M;_2t6}Fs28fjii_7?v*cC2yKh#sK-ZUbKF=?ugeKEzeN7~5KSS;?%PDMw>wh)gU9l% +f%ioBm^ZYuZ0hMFc461EUYKv*_^Ex_f;s~6;*=Gl{Iv~I=Ysc@Ort=U;^hJ<$=i(51v6q}Me2;NFUz +El3xQzF>pN@E4D7sr!Z+}z;OLBioEn4kq~!`~>Mk+HX65d5?X;P6V53&jgc3hMtL=&#z*HVv+N-alw +P>-Z&Rlue>D?dgUwoEi$?hnkz;5LI7tS9_5p2Cqj8Nyk4uNwNJ8Xp!m+%(36(@((@6H-m=mSOxZdbq +JATjHP{`{>3tWt4y3ht%l{}2`oj@Y`UEfp8;LVyt%xZv8@c_%w7koF)tSvOMg$fQJPu;>ouxpnV_n< +9NJD`Ts{@2Dh^%{T6SIrhy7nOh@!v`!Q<1gcUw|b;U!vVHa71(^WG_4c2D(cgVod1Mdek2&;aa?`-j +M^R$v+?CvgE=6(Gwt@x +6j{XHlxd&u>9E~jZaM~W)d2w=QWp=k91>SVXzzVf%JM`$7pNCu{J)AvU@acJ+~^oXt5ypTeP^`c&d_ +|l6yLJ1?4+!q6y$>0c^9gh!)3{@sMq9ir~Iu8TPs4*gwDR-AA*ZjYCrc#OJ&s`J;=TZjv87sq=Hc|0 +Dx?E!iHpX|G$+p>|7S5NH=kU>#u40R+Uq_$}7vhtA_|1Y_R`9L%s16 +^$&zObfefuc)CZo0*eoJHmH`uG)3j%<0TJUV;43d_sZu8ccKm0M&@|L*AnSEzjNUQb|_l|$JDQAsn* +;vuT%Q99Q8hT?{>K66V((zaA71Cv3zoQ|OE&`Dkc`R#AzGB1c57U=#XVB_WrY27jA{+34LM^MDOJ0u +sh)7V6*n4xp)+OOJc{S}ecGz!uVNI1#`$xS|wPdK_z@ksl4}MEt=OTFvBK&_RRX +%Vc5kKCpj|mpuZ(aTApwm(_MIQWY4eEEuH{8RIiSp$?U96r%R* +LaBSwNZSzCFEmpDEiAQ=-yh6NLN-6b=l;3mbMy9P12Dlf7IG&GA7BzXtuI^w3I@bjB>dgcx%+tJGjx +G4e_9P9n6@+$)ej1d%XN>Kz59Q?SMF_}WQX3lf(lY)z#ByGEEW*d`=V^ +{4+5`_ZIYPp5VhWX1)tk4hNnzf2RqUe-8N;9(-|+9;LoTjzzTt|5IKY!nkJa!YF%ny^2=#8nI$-!u+ +%2I;6(;_949PTGvG{oIF`_<3+H8cTM4TiGgo2bshr?ti(#luwkHe(5*XaVUJLB1-wSHqx=E5x2^s;; +DdRJM^5pg&GdVF}$JLyKi4DEu=FeI_c?x~QK}R(6slSIAdnDb-0nVjc1woureg^QPpszW- +>)w{J~!Z+3=t^>+6!afgri3ul?e@ep&{U>P5%Y9YOX@G1v0vh&+K`%?}#>oA>;N^FH!;feM4pbe+vH +%&7}V`(>`ek$r}8iAj|UWi%s^O`a{6~FdOmz6rzfwffCW5MQt>tBg0A?%sRp$(Yc3g8ahzHEUKYY5) +f3N>q^t_+73l1K?E{0-me5UIAB?Mmv2k;9VUQhS&~*gsQvI{~6YpRig=pie3+0a@P-cGtY(ie}_CYX +T?dNVa0UY*@=ibp4wr?ZvpjLZlR)b5Eg-`qA>!AvE8{F4`UA-vw2b3X}Rfd~{|`Hp-05pQzEV^gmdK +y7k1k^hhR&1v#g4Pninf`IoIOaw_pvw%1ZR3`$a)Z74^fH}n8r#K$!4dcfEpg=uH-W*`e8kq;|bV4q +o60J7`~%*_Txg5Km-rXkdYt2>Vofo>u6bekS|b%v^bK(=uqP7UUQEPm$*ldf0>sQtsF>5kHS%*&^`G;)3cXDwIzMQ2Mku6J2Q1TUFpL +ran$y);o&2R%k8`S4&0-qcnU!!O|cu3!>Lm-q?TGJ)O>7MSoID`UVNo%4+^<9&W`E~Jc9DECR7n+01>G5?JzYP*BijSr +D`M4Omsl&djTy$=ay;_Y2)?y}DO?VTo9q&}roU+z_ByV02@mPps_tS1ec2lNup$op`^={uCj_2#9PB +&mTcg5H%yQTED?26$a9hakKzBXCoBHz_D{TvB3jR7;!pXM%>s(xzTJls2O+f?@oLQ5OH(k~Pu>T-nb +=ULIz6R&fF9b6X|CeF6dA!n$TmTuWFJX&=&kPhsKofHGR9XCqlj?&F^7xNqY9Dr6YdBI#z0>mJUx}DL+9z&{0=X^99hiKr3 +rV2cAwNRjI=Y5g#l=cuuLscel$qpE%-A{d;(YbcM_?i3AnHwAh5!xWJKL#(UY{s${HzR1#EUQc?U|6 +9qMmx@^i2!TP8W#xesmo9x5%$ty%vw{%h=RAdXmtGk$nN+Ci;0c+Dk>ntIOzGCQ2^u%NboR)ZPmmXl +N=YB{I?)`9U>i^VvT-6}**US}Ubqp7csA?5tOqlUOu8Em?*Enlh8}bfi0Dad|;T+2^4gjqQYa?$Em$eNUnbDLzqNFY^zcC)fC9BOi;=M%jxfTEqp(%fFTRt{) +YjP}{=aK1A_3QAo5xAqGpQeJ2%hx3XPKHbzM8wJxmw5X~G8w11=T;SXZ~UW||*V=uPnXA9)wEpxL3C6}KcaF;s( +P7DfeB=*_CuC@PRx1}az7&}J%LS;yWj@^=u!E3iX66a8}4gF2pNsd!LfItmQS!>oHy*UFk87>xcpbD +ms2?nXx&gRuPbdhkhAPD~_((y#kCZ6YF^d<1l^d_Qirw@6vF|J2kpodf&me`pNY=H=v;I4TEcLLI8& +{1|koPS?6jQl0kzH5U-?T@9RHu$qiN+B4r;*+p(B`lg +e(EYM@_KA@5Cd0d7$@EP>nP5Qkx~(OaHB8H=U%H{W0#hEOB4}h9v$+b|5&3flh8wd);*34K3yn}B(a +(t)_5|SpzH&Bv5dW&i2M|W6mMR15f8C%+f$OyqFYorc6{K@g!FtU*Tz}Mo-g;6TzBz=T2Hm*$r8dJq +VR|=?LB@jf!ak4%rAfFU?C>^C1V`+4?Kd)atv@%+6#U(>;X*ulObIX{R3Yqk&xdqkU4CEsdo;UGyI4 +~UdQlO~j}m};eVF~Qb^qtw1+)>!{F7U;fk_svkwdfMFMi05GBBcjg8`fXn+0oi0twPv{9_ZQe+pF2BXXjUt76MgnLPIFa*~4yAYW^c +LLbY2%6$%CtJLy{Jb%*%UNc%1o~Z+#Dg|^$3ZBBJYUr>F_C8wqVsD#hrd4b +N$Xh=GVllNEgHc}yA5E<-uCoyHc*Z0iNbrG2*a}fdTigKf-yW53sTD#ucX(C#;1*2HEYNnWFW|U40p +q*|E>DU<>y@U^URmO^%tAk(aRqz9b+LYesVYbbcA9xQprI7WtOq +MnlrIv)cqzKBLGN%6S!Fd;n%bM1d>05`S!lV^ynLo`SHKGdkAWO$=<&(8b!|i1lxnwt;*XMdSJ1hxi +c&nb~kcq1nlO305|a8BBXsD*jDpa$mnp90MP?Nt_Vw`H>}^xMwS5}@RNXR=Y+hJe+A~E6e=cQ9#>(` +Bf;|WKGX@Z@o+=0^34Nx0%B+sH|79XPW+ppmE)61>_MQx?#|{^KRl)Z_tw8_bINuvFhc905R29v#E+Ssbq24BFCIr2&0-Pv{Y2A4Ns`P!F5%9rcy1N_ilzlJ&b@l +_($(^{;oy-+&J*|7Pr!dWqXl@2X5fmiy9MA@WU{WQjykT+T`o;LVT->c@gq*uDu0=b}}qU%|>g^JZU +eqz97RePCj-P(ZxD>PayX{tsNcYoT3#+Mys+ATY3miTw*tlUZ;;tW><;(ftuuQnb+;Qt +HG076X>t0v4Rws$LK=lN8%U<{fg_qC^k(s0v+|_6VmD|?Ulu;7uvWIM9XaZEV%^X+o0B!@$HvJ_JhA +OugkgL^d~}%4A+H2{=SKLCuX)}^GuyaU!BM47Biha-fF75Mz!s9-zC2|jIcslv?wfKaW#HO8A*hvcJ +fvJsSbq!35|^7fvR{D{*Z<+to1{T(cXUhu<=IuvCl^7&ok%rbbZpVSmk}9TFi|mXRWuh_I;BnygnwVZ)+r+G4PlEi`zg*oj!s(n&WXcocWd?(^*OwlP;M +b9}SuqQ3pW1-hWV=V8i1;|p`CK#RPJ*A`tCPSz5U5c7wA1gYN|`TO|xh?650X(yDP0}qH33pMC|HRz +79W$2Gm>YLw{Q#7Mmw!KY!p)6s6h6qp7h2jr$G0OheuFp)>9mP#1P$F_iJbRQm)OFHHceu}x?_655B +fTFx(P-!eN84c9?+{X=##HFyd4cLFvJ>6`yqvrYMO{5#xLE&Vc$z3t)dP+UZudblI((np@f)}CjYv~Z)4qBIH~T{74LT#9&#kC +P=_5<-!zE-dTlv_$a`6Py`Q^oj_M_CVh9k7$X0M21@0{6n%TmU%LV;#iDA!#HwJuqJW8*a(OHyc; +k82c|@c6F|%pE>?iRMqBZ0A!g*1*(}R^dyfZE8~k#=8_sQx1v*AX-q^49qK`Tef-0>MkV)|4`k$WX- +`uzd#Em?8>_Uujw%m`Y_APxS<{;1KP3t{cH0iaBfV8}j%CW8s}!;MF +-{-)-AjPxtffNF^tI56bFoad`nZ0*jSh%?}-0-z+fB$2Ft8`vyTG!rX>toYA1O_V`!~QoWW3i`$=$e +B=xuSCJmWXX#d)H><~N0y%IX4Mq{XQPI$tND#6`R%jNe_y7`u*34fqnJa9{${T*?ThYRJI-?S6_+)% +L@#Dnp*qU-11^-yX?ZWS5=DtDQB+2SWOwV$snBrTkQh2n93d-T$Z=`amyZln{@|dvwcq$tpm<=NeM- +FvdnW5jh8>9YTlTz}p1r9Z9m7e|JF^>#qA3K0Po-r&n@0PfxnAC@l*dD9M?p|uMoI@fSYx_{LKa<0+ +ONocg1^lXsuP74MU4s#Q-406r`^kkMYa8~`2g{nqf8y%_sA-P67B8S53y9{)0P1siy#gSosU7_NWp~ +|942}rFTKl1Pecp~F){vZsk{h!^iNE8icb-7|cYbi+^w!7N(%ur!up&*E5^R}Th^xzz5{<4_ACbm?% +py0jbHig@)VykjwGgKvFQ54o2j4SyPj_ +&Zpod52L<6%u!4dwHnFiq4Rux%A$)S!$`;D(#V0GGiB+4>gPY+OI*((u>GqCk`b|NEWM1=(lHwxYsy +YIh=lL64S_xm1(DL}VN#&8qt%t%07Jm~hJSZgupbc0e4XJ2c`|S6~kIfCz<)ZE?9aLxx#Q8}H9I?De +X5G6-P7&)ORE6q*r@kGpgJ!I-C;DGQBZiX$PwCEPh&l5JDc3lfeZLgehsp?|?w6^`zWj7+UO!k8dbY +D6#m;MqDuJ(t;3H4Djvp-IB&*98?%s#Dol1HRKLPs0MPa*9@gUB5CScwCyxDnQ)=%athc7Y+_;!ZDLV3Dncln<{5d4XoDs6oHAI+^Y89dUj-hW02VDIc&-)j~VQl_=NpB(j; +_Pf(jIY`%Boa<_57>}!ME7?QNy^77hw3E!DY#Wli)^&_=I@xtC<$YBG!YKX>aJdODjdDojFaW5RT#xOKpY6YCIy6F`F@s0{0%!`AVIlVZbe-X}GFj!Hq +O1r=x*?`&fwUo@uFNvu2{l_X>Bx_CG=+g<_z4?yM5pMm$2uImlYL(>S8Ao;rtc +txoVdP=PQ!R48l;k1v!F)WltnVFNqY-dtVHL@d2_oaj$SSQg@=CQ7Wkn6Whf1*W#PMv+wF7hfY+dZi +BY2eOLKS9$rfUg2jKmdiOKq+@Xp*V?9{T;EeKnsTW?`8kcKG&1WZpB +T*LV@e-XT=@7BdNQ9((2JwwxXeZb&Dgmg_JNvwB=$3DA>H@i`cg2TFBNy!x%j9jY+B +A4?Cd)lI6n<&TMfP-b-9oXkY0ZNSL#oz+hm{# +-F%{z4e1lU|Ho2ojryfi#_M$h6*YB%ryCKdl#FtdUe6jRb8M@2g2V1m+i5isBt_$lPX;ykj*e{#4TX +^>2@lBl(t4xDUPJ?g9lq5=QrZc{+D6Wj%tPx(9iB?*73eL&8I?m2yot4*}-+ot+#=?0)f8iBHZ)R<6 +1}h-#T-ymrT4d{$C>Fxv&6>X`f7=PZe$WCBSJ&K_8s>rb)EX50r{kGD1b8xp8QE_&&Oh(a)sXFLHg( +n=T95asVyAWpY3A4f6dU|Jp>+(35eAM%xy~7PYE9j3NbaC%#@U1(uSepXBnYn!|Fv^326ymG!b-` +3i_{hHqo6Lxr@6E8lrD#`~HAUyMRjDHdw`B-Bg(U&ly6g-1^#G-;1|dAlIZG3<#z)1A;L#+Bzk*e(R +Rm`uxWnqz$K9zea(%Y^D^zVh-79arlGBgp3QHajDOPQVt#x{%14Lgx6UlJIcpciBcTO#!FfgjCSXSh +u2_N-(*#B?od=s!?{>cY<2yT@^{VA_vOCeC%nvE?o%ILyMr#fRG|N2KYT3%XujV?yfOlFJ5Sr*@)b3 +4<)=2B|*gj`S~B*(c%$*B08nft)IX0et=QNGl=q{);?a`PPr24R>q(1`c7{}Kz)(S+ZF$^3zL6n3=> +o&$MGt#*X7nM>LReV}LjK~{NhnOf(>mrN77R}QHd{FlK3;wpAj+5a#ud2LjtW$%stT2m3I*UBo +nOlb`<4`Re=Jb~$W+-gvd1eLl2lKX_*O7ouE~`&c7S*k+)wI^z&ofaG`IfLd +GyMVO~z#x%dacb%L^TX8nlVkzf#AHGoAqHlA(in|K)^5vUF(($Ci96QIbPUh-M;Ec`N^K`f$y@Fl;lJa|6kunk=yA%lJmQK}S>(9A3v +7K5=s9axS;<3G{ju8s+*Cxvd5`3K53a2E>AFQs1D-)~4p>(u(d%T8yput| +seZAC@CI4SDjxurQH@(f>WBxlBoG_LeRXI@RTi~`hwGP9eFquL);_Lwx#2N1HoPylXuxW174woy~Wt +a1l^->pB+~o*|`G)!Hs8*oK++v4~A}<$Q)<$$-<+x~jsP9FcEh(a;gtyq}ZXW5EDrSXH@!3oA+?qZI +V_lg-Ls}8seEAMA*kytnJJymMEcY|RLon|#18>IaNKuinNun8V-NQ;RvTOw?d8{BTay6^aJxvlKE-= +lkc(v=y!&A%CX$r(an=P*0iBD@;Yt?-WNIKlS!^vOv&jU3;dPnsi{j#zT$}l^pCUw; +efwospOxh#cQq|dA=k>{dV`0R}WPIn3WmHW_<`G~9@!P1On-E>8-ZXbNbnu`!p4I +7-ziS7V!Q{9gGXp~bpP?|k>RM~$6`Xsv!%(1NGfJNn0;&c{HUEDEQ;?*3D&fVZtK&_61LP-8%aGXD9 +H0S~65QBIgD%SZ^`5KIJFGwRM1NtI=ol6Fq;7=kmeV<#@&_Vu(fF~K?W9Mt-rpAG@6^DYIWOxU1z`0 +QCU>V5pYcqxm73WxbH +S6a=|YV4=?Pw?X8RLzkPxfg6FbD@IBR`<`+s}NZc*RtItlwQ_K<*t85xeWRm5g;qzXGAyGhb3p +e#Sv)5DsYF0$j#P^XGe?DUqie +T7=6cH_2z;P(%QwkVqk5{h?R(qL+R-CQT&@{goN4ebzDJ$5@D002@hOk|)80UhyCbP;-58v2q2wrTa +w;DCm0MiSvfx$zoNP@9>Lze_|8UNmIgcCiH%=@D8ilXe|jx$thN_zbTJrLL_fkOAY@AKG{Hn3}^e0V +*=QXDDjV3hC*)x+$dAWH#%+&5zuM>HnYO4_()}Vcim+-bj~vf|MM-syLF5_e^Ss|!(~-VIB{h&sgg` +i+EW)^!(j+t5D=43{4!&`WIwHoUf(6$4ByPt7-QSI&%V_yz=hq_1mh=KwKhZ3u}KYhKt5hnackHEOy +(;|UWx57C-w^;g(J*QlI);ST$86mbo5n7*TCJ>fLY3<-88( +$4GS*=>4a(290iZ2(^hX_&EKLw~Ti%g`Wo8! +BE=-5F=f+VTH>)YXgc+&K40AY?ga8oSo7mPt9KBMB6n2`U(g=EUvJaop_3Oa20QmWbbNg;<>5#9^Gb +XPfc5vUL`v}3!Zp4rQ6cF9L7h5L}t>DSEIM?1cccx#J!ddMcxCO_S +TU^v>MEG}c%%|DcB>}7ivzLCU5L!&B|nCZ69y-9yg>(*S4pK2JPUv$>AXd8;qhl?TtES!iy%5@P9OaYC}tBj7*%Y8DE}9jKWn7VrWAB;$_x +B+x1N-)Mz0;G2Z9=Y9G488(&mv^be@FpS9#K+Njmv(Ie}KLP3u#P;*Ux(0SQ5(!5M;OV+A{{&%<>Xmqf$@Y_Qk-!pqy;<=iBw{Sd@aP%iiZ?(&g)Yw3jj1*U-ySyfcf +yz!vpH;~76hxqi^ET?ifP8*^x#Kfze!FBAhc@7@qGWWKSD`eg7e2tSFN7~z+vS+0BA`#Y!nHPeKluV +mZ0Fnhiz0*7QkO*4lDjPv+dbLK7pX__s5y=zmcTT`?Aog)nPq21b7`wuNv@Y)k6$k(pfly~B_rGvF> +O;7ev=C4Aao;SW7V_;6EQPa+&?nC2n5t-9{@R>U2f!EPNZVvfyP63V0tSi<{yCiMiuz9xVZyNgdHIV +P02QH=^Od$X55K~vBI1)}p`draTG-_}revHT9=0a@&IrKAu0`zDTUa%J+LMhZd;I-kNAL3^)i5+#gw^X081~{BwfLK8;QeG|e;bsgWd&-6=t04A1|wwfj!#&J +-|mAoP^&U?XKp8w#z@K}}u0)+k5QCzmroxDnBJW*~xJ?mKIYZM?X;%I8?o8v^NqonrU?>(s9^w%t41 +J;}+Jc7co7!3AQTkzgtkwOl_e(5b3)I6u;OO-Z0iLP9_ +BM=-W6rtAE*rnONVM^$ZhYV6M*J-u!8N(3kSP>BeA$x^U-@#_&ALL#c%Tldonq=%Wauid&zVf>dEp;jn(HPcyJ|$V4fq*+ZrH#}jH$>emNA{rzgGTEN3HURu^+Z;G;ue=n8PK{4l6>FF +3Yh#0XF)9=DIo(T$wJSRJw`Dm{E%Lrl-Q_95KH$W-;>Gk&n9_l<7~OPoQkR@R}Ez?+lMoUn-8<@8`r6(4BRTpPem}l>cL#G8z}tRKUeM+AzAAS3WXBr8`jb(_L +pTj~X6>n8xGYy6rjqM_4l~3(TAcg7kXF*#VhXh_XuoLwbQ!I>1*qdX3O|Nq?A2Mst;VP_i*r@wRx-s +BWOQPcn Dict[str, Any]: return asdict(self) + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# EXPERIMENT LOGGER +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +class ExperimentLogger: + """Appends JSON-line records to a timestamped .record file.""" + + def __init__(self, enabled: bool = True, cfg: Optional[GolfConfig] = None): + self.enabled = enabled + ts = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ") + self.path = f"run_{ts}.record" + if enabled: + with open(self.path, "w") as fh: + fh.write(json.dumps({ + "event": "run_start", + "timestamp": ts, + "config": cfg.to_dict() if cfg else {}, + }) + "\n") + print(f"[Logger] Writing to: {self.path}") + + def log(self, **kwargs): + if not self.enabled: + return + entry = { + "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(), + **kwargs, + } + with open(self.path, "a") as fh: + fh.write(json.dumps(entry) + "\n") + + def close(self, **summary): + self.log(event="run_end", **summary) + if self.enabled: + print(f"[Logger] Record saved → {self.path}") + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# DYNAMIC SCHEDULE +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +def configure_schedule(cfg: GolfConfig, world_size: int, step_secs: float) -> GolfConfig: + """Populate runtime schedule fields from a measured per-step time. + + Args: + cfg: Configuration object (mutated in-place). + world_size: Number of GPUs. + step_secs: Measured seconds per training step. + + Returns: + The mutated cfg with max_steps, warmdown_steps, qat_start_step set. + """ + target_secs = cfg.time_limit_sec - 5.0 + effective_secs = step_secs * 1.02 + n_steps = int(target_secs / effective_secs) + n_steps = (n_steps // 100) * 100 + n_steps = max(2000, min(n_steps, 20000)) + + cfg.max_steps = n_steps + cfg.warmdown_steps = max(500, min(n_steps // 4, n_steps - 1)) + cfg.qat_start_step = int(n_steps * cfg.qat_start_frac) + + if rank_is_primary(dist.get_rank() if dist.is_initialized() else 0): + print(f"[Schedule] calibrated={step_secs:.3f}s | ×1.02={effective_secs:.3f}s") + print(f"[Schedule] max_steps={n_steps} | warmdown={cfg.warmdown_steps} " + f"| qat_start={cfg.qat_start_step}") + return cfg + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# DISTRIBUTED SETUP +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +def init_distributed(): + """Initialise DDP process group and return (local_rank, world_size, device).""" + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if "LOCAL_RANK" not in os.environ: + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + return 0, 1, dev + + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + import sys + print(f"[Rank {local_rank}] CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES','?')} " + f"(cuda:0 → physical GPU {local_rank}) | world_size={world_size}", + file=sys.stderr, flush=True) + + torch.cuda.set_device(0) + dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) + return local_rank, world_size, torch.device("cuda:0") + + +def rank_is_primary(rank: int) -> bool: + return rank == 0 + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# NEURAL NETWORK MODULES +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MQAAttention(nn.Module): + def __init__(self, cfg): + super().__init__() + self.n_heads = cfg.num_heads + self.head_dim = cfg.hidden_size // cfg.num_heads + self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False) + self.k_proj = nn.Linear(cfg.hidden_size, self.head_dim, bias=False) + self.v_proj = nn.Linear(cfg.hidden_size, self.head_dim, bias=False) + self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False) + + def forward(self, x): + B, T, C = x.shape + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2) + + pos = torch.arange(T, device=x.device, dtype=torch.float32) + inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2, device=x.device, dtype=torch.float32) / self.head_dim)) + freqs = torch.outer(pos, inv_freq) + freqs = torch.cat((freqs, freqs), dim=-1) + cos, sin = freqs.cos().view(1, 1, T, self.head_dim).to(q.dtype), freqs.sin().view(1, 1, T, self.head_dim).to(q.dtype) + + qh1, qh2 = q.chunk(2, dim=-1) + q = q * cos + torch.cat((-qh2, qh1), dim=-1) * sin + kh1, kh2 = k.chunk(2, dim=-1) + k = k * cos + torch.cat((-kh2, kh1), dim=-1) * sin + + k = k.expand(B, self.n_heads, T, self.head_dim) + v = v.expand(B, self.n_heads, T, self.head_dim) + + out = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.o_proj(out.transpose(1, 2).contiguous().view(B, T, C)) + +class SwiGLU(nn.Module): + def __init__(self, cfg): + super().__init__() + self.w1 = nn.Linear(cfg.hidden_size, cfg.intermediate, bias=False) + self.w2 = nn.Linear(cfg.hidden_size, cfg.intermediate, bias=False) + self.w3 = nn.Linear(cfg.intermediate, cfg.hidden_size, bias=False) + + def forward(self, x): + return self.w3(F.silu(self.w1(x)) * self.w2(x)) + +class MacaronLayer(nn.Module): + def __init__(self, cfg): + super().__init__() + self.norm1 = nn.RMSNorm(cfg.hidden_size) + self.ffn1 = SwiGLU(cfg) + self.norm2 = nn.RMSNorm(cfg.hidden_size) + self.attn = MQAAttention(cfg) + self.norm3 = nn.RMSNorm(cfg.hidden_size) + self.ffn2 = SwiGLU(cfg) + + def forward(self, x): + x = x + 0.5 * self.ffn1(self.norm1(x)) + x = x + self.attn(self.norm2(x)) + x = x + 0.5 * self.ffn2(self.norm3(x)) + return x + +class GolfTransformer(nn.Module): + def __init__(self, cfg): + super().__init__() + self.token_embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size) + self.layers = nn.ModuleList([MacaronLayer(cfg) for _ in range(cfg.num_layers)]) + self.norm = nn.RMSNorm(cfg.hidden_size) + self.head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False) + if cfg.tie_embeddings: + self.head.weight = self.token_embed.weight + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + + def forward(self, token_ids, labels=None): + x = self.token_embed(token_ids) + for layer in self.layers: + x = layer(x) + x = self.norm(x) + logits = self.head(x) + loss = None + if labels is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100) + return loss, logits + + def unique_param_count(self) -> int: + return sum(p.numel() for p in {id(p): p for p in self.parameters()}.values()) + + def fp16_size_mb(self) -> float: + return self.unique_param_count() * 2 / 1024 ** 2 + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# OPTIMISER — AdamW +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +def build_optimizer_pair(model: nn.Module, cfg): + opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr_matrix, weight_decay=cfg.weight_decay, betas=(0.9, 0.95), fused=True) + dummy_muon = torch.optim.SGD([nn.Parameter(torch.zeros(1, requires_grad=True))], lr=0.0) + return dummy_muon, opt + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# LEARNING RATE SCHEDULE +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +def cosine_lr_schedule(step: int, cfg) -> float: + W = cfg.warmup_steps + T = cfg.max_steps + decay_steps = int(0.1 * T) + stable_steps = T - W - decay_steps + if step < W: + return cfg.lr_matrix * (step / max(1, W)) + elif step < W + stable_steps: + return cfg.lr_matrix + else: + decay_step = step - (W + stable_steps) + t = decay_step / max(1, decay_steps) + return cfg.lr_matrix * 0.5 * (1 + math.cos(math.pi * t)) + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# EMA +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +def make_ema_shadow(model: nn.Module, decay: float = 0.999) -> AveragedModel: + return AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay)) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# TOKENISER +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +def prepare_tokenizer(rank: int, cfg: GolfConfig): + """Load or train a BPE tokeniser, then compute common-token bias list. + + Args: + rank: Current process rank. + cfg: GolfConfig with tokenizer_path, vocab_size, train_skip, etc. + + Returns: + A PreTrainedTokenizerFast instance. + """ + t0 = time.time() + tok_file = os.path.join(cfg.tokenizer_path, "tokenizer.json") + must_train = True + + if os.path.exists(tok_file): + try: + probe = Tokenizer.from_file(tok_file) + if probe.get_vocab_size() == cfg.vocab_size: + must_train = False + if rank_is_primary(rank): + print("Loading tokeniser from disk …") + except Exception: + pass + + if must_train: + if rank_is_primary(rank): + print("Training BPE tokeniser from FineWeb sample …") + corpus = [x["text"] for x in load_dataset( + "HuggingFaceFW/fineweb", "sample-10BT", + split="train", streaming=True + ).skip(cfg.train_skip).take(25000)] + bpe = ByteLevelBPETokenizer() + bpe.train_from_iterator(corpus, vocab_size=cfg.vocab_size, + min_frequency=2, + special_tokens=["", "", "", ""]) + os.makedirs(cfg.tokenizer_path, exist_ok=True) + bpe.save(tok_file) + if dist.is_initialized(): + dist.barrier() + + raw_tok = Tokenizer.from_file(tok_file) + tok = PreTrainedTokenizerFast(tokenizer_object=raw_tok) + tok.pad_token = "" + tok.bos_token = "" + tok.eos_token = "" + + topk_path = "top_k_tokens.pt" + need_topk = True + if os.path.exists(topk_path): + try: + t_ids = torch.load(topk_path, map_location="cpu", weights_only=True) + if t_ids.max().item() < cfg.vocab_size: + need_topk = False + except Exception: + pass + + if rank_is_primary(rank) and need_topk: + print("Computing top-K bias tokens (5000 docs) …") + from collections import Counter + freq = Counter() + for doc in load_dataset("HuggingFaceFW/fineweb", "sample-10BT", + split="train", streaming=True + ).skip(cfg.train_skip).take(5000): + freq.update(tok.encode(doc["text"])) + top_ids = [x[0] for x in freq.most_common(cfg.common_token_top_k)] + torch.save(torch.tensor(top_ids, dtype=torch.long), topk_path) + + if dist.is_initialized(): + dist.barrier() + if rank_is_primary(rank): + print(f"Tokeniser ready in {time.time()-t0:.2f}s") + return tok + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# DATA PIPELINE +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +class StreamingDataPipeline(IterableDataset): + """Infinite in-memory token stream with coprime-stride cycling. + + Tokens are pre-fetched into a flat int16 tensor. Each epoch uses a + different starting offset (coprime stride of 7919) for diversity. + + Args: + token_buffer: 1-D int16 tensor of tokenised training data. + seq_len: Context window length. + """ + + def __init__(self, token_buffer: torch.Tensor, seq_len: int): + self.tokens = token_buffer + self.seq_len = seq_len + self.n_seqs = max(1, (len(token_buffer) - 1) // seq_len) + + def __iter__(self): + seq_len = self.seq_len + epoch = 0 + while True: + offset = (epoch * 7919) % seq_len + N = len(self.tokens) + for i in range(offset, N - seq_len, seq_len): + chunk = self.tokens[i: i + seq_len + 1] + if len(chunk) < seq_len + 1: + break + yield { + "input_ids": chunk[:-1].long(), + "labels": chunk[1:].long(), + } + epoch += 1 + + +def build_data_stream(tokenizer, rank: int, world_size: int, + cfg: GolfConfig) -> StreamingDataPipeline: + """Prefetch training tokens into memory, with disk caching. + + Args: + tokenizer: PreTrainedTokenizerFast instance. + rank: Current process rank. + world_size: Total number of processes. + cfg: GolfConfig with max_seq_len, train_skip, armenian_weight, etc. + + Returns: + StreamingDataPipeline ready for DataLoader. + """ + seq_len = cfg.max_seq_len + MAX_TOKENS = 250_000_000 + script_dir = os.path.dirname(os.path.abspath(__file__)) + cache_file = os.path.join(script_dir, f"data_cache_rank{rank}_of{world_size}.pt") + + if os.path.exists(cache_file): + if rank_is_primary(rank): + print(f"[Data] Loading cache → {cache_file}") + t0 = time.time() + tokens = torch.load(cache_file, map_location="cpu", weights_only=True) + if rank_is_primary(rank): + print(f"[Data] Loaded {len(tokens):,} tokens in {time.time()-t0:.1f}s") + if dist.is_initialized(): + dist.barrier() + return StreamingDataPipeline(tokens, seq_len) + + if rank_is_primary(rank): + print(f"[Data] Prefetching {MAX_TOKENS:,} tokens / rank …") + + arm_wt = cfg.armenian_weight + + def is_useful(ex): + return len(ex.get("text", "").strip()) > 600 + + import datasets + if dist.is_initialized() and not rank_is_primary(rank): + datasets.utils.logging.disable_progress_bar() + + ds_main = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", + split="train", streaming=True).filter(is_useful) + + if arm_wt > 0: + ds_arm = load_dataset("wikimedia/wikipedia", "20231101.hy", + split="train", streaming=True).filter(is_useful) + stream = interleave_datasets([ds_main, ds_arm], + probabilities=[1.0 - arm_wt, arm_wt], seed=42) + else: + stream = ds_main + + token_buf = torch.zeros(MAX_TOKENS, dtype=torch.int16) + cursor = 0 + t0 = time.time() + last_log = 0 + LOG_EVERY = max(1_000_000, MAX_TOKENS // 10) + it = iter(stream) + doc_count = 0 + + while cursor < MAX_TOKENS: + try: + sample = next(it) + except StopIteration: + if rank_is_primary(rank): + print(f" [Data] Stream restart at {cursor:,} tokens …") + it = iter(stream) + doc_count = 0 + continue + + if doc_count % world_size != rank: + doc_count += 1 + continue + doc_count += 1 + + text = sample.get("text", "") + if not text.strip(): + continue + ids = tokenizer.encode(text, add_special_tokens=True) + if not ids: + continue + + n_take = min(len(ids), MAX_TOKENS - cursor) + token_buf[cursor: cursor + n_take] = torch.tensor(ids[:n_take], dtype=torch.int16) + cursor += n_take + + if rank_is_primary(rank) and cursor - last_log >= LOG_EVERY: + elapsed = time.time() - t0 + rate = cursor / max(elapsed, 1e-6) / 1e6 + pct = 100.0 * cursor / MAX_TOKENS + print(f" [Data] {cursor/1e6:.1f}M / {MAX_TOKENS/1e6:.1f}M tok " + f"({pct:.0f}%) — {rate:.1f}M tok/s") + last_log = cursor + + if rank_is_primary(rank): + elapsed = time.time() - t0 + print(f"✓ Prefetch complete: {elapsed:.2f}s | {cursor:,} tokens | " + f"{cursor/max(elapsed,1e-6)/1e6:.1f}M tok/s") + + saved_buf = token_buf[:cursor] + torch.save(saved_buf, cache_file) + if rank_is_primary(rank): + mb = os.path.getsize(cache_file) / 1024 ** 2 + print(f"[Data] Cache saved → {cache_file} ({mb:.1f} MB)") + + if dist.is_initialized(): + dist.barrier() + + return StreamingDataPipeline(saved_buf, seq_len) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# BENCHMARK CALIBRATION +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +def benchmark_step_time(net: nn.Module, device: torch.device, + cfg: GolfConfig, n_warmup: int = 10) -> float: + """Measure average wall-clock seconds per training step. + + Args: + net: Model (unwrapped, on device). + device: CUDA device. + cfg: GolfConfig with per_device_batch, max_seq_len, vocab_size, grad_clip. + n_warmup: Number of calibration steps. + + Returns: + Mean seconds per step as float. + """ + B, T, V = cfg.per_device_batch, cfg.max_seq_len, cfg.vocab_size + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + net.train() + + tmp_muon, tmp_adam = build_optimizer_pair(net, cfg) + + # One throwaway pass to initialise CUDA state + dummy = torch.randint(0, V, (B, T), device=device) + with torch.autocast(device_type="cuda", dtype=dtype): + l, _ = net(dummy, dummy) + l.backward() + net.zero_grad() + torch.cuda.synchronize() + + t0 = time.time() + for _ in range(n_warmup): + ids = torch.randint(0, V, (B, T), device=device) + with torch.autocast(device_type="cuda", dtype=dtype): + loss, _ = net(ids, ids) + loss.backward() + if cfg.grad_accum == 1: + torch.nn.utils.clip_grad_norm_(net.parameters(), cfg.grad_clip) + tmp_muon.step() + tmp_adam.step() + tmp_muon.zero_grad() + tmp_adam.zero_grad() + torch.cuda.synchronize() + + del tmp_muon, tmp_adam + return (time.time() - t0) / n_warmup + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# TRAINER CLASS +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +class GolfTrainer: + """Stateful training harness for GolfTransformer. + + Args: + model: DDP-wrapped or raw GolfTransformer. + ema_shadow: AveragedModel EMA copy. + data_stream: StreamingDataPipeline instance. + muon_opt: AdaMuonOptimizer for matrix parameters. + adam_opt: AdamW for embedding/scalar parameters. + device: CUDA device. + rank: Process rank. + world_size: Total number of processes. + cfg: GolfConfig with all training hyperparameters. + logger: ExperimentLogger for JSON-line recording. + """ + + def __init__(self, model, ema_shadow, data_stream, muon_opt, adam_opt, + device, rank, world_size, cfg: GolfConfig, + logger: ExperimentLogger): + self.model = model + self.ema_shadow = ema_shadow + self.data_stream = data_stream + self.muon_opt = muon_opt + self.adam_opt = adam_opt + self.device = device + self.rank = rank + self.world_size = world_size + self.cfg = cfg + self.logger = logger + + self.use_bf16 = torch.cuda.is_bf16_supported() + self.amp_dtype = torch.bfloat16 if self.use_bf16 else torch.float16 + self.scaler = torch.amp.GradScaler("cuda", enabled=not self.use_bf16) + + def fit(self) -> int: + """Run the main training loop until time or step limit is reached. + + Returns: + Total number of gradient steps completed. + """ + self.model.train() + loader = DataLoader(self.data_stream, + batch_size=self.cfg.per_device_batch, + num_workers=0, pin_memory=True) + start_time = time.time() + global_step = 0 + + if rank_is_primary(self.rank): + base_net = self.model.module if isinstance(self.model, DDP) else self.model + n_params = base_net.unique_param_count() + fp16_mb = base_net.fp16_size_mb() + eff_bs = self.world_size * self.cfg.per_device_batch * self.cfg.grad_accum + print(f"\n⚡ Training | {self.world_size} GPU(s) | BF16={self.use_bf16}") + print(f" Params: {n_params:,} | ~{fp16_mb:.2f} MB fp16") + print(f" max_steps={self.cfg.max_steps} | " + f"warmdown={self.cfg.warmdown_steps}") + print(f" effective_batch={eff_bs}\n") + self.logger.log(event="train_start", gpus=self.world_size, + bf16=self.use_bf16, param_count=n_params, + size_mb=round(fp16_mb, 4), effective_batch=eff_bs) + + self.muon_opt.zero_grad() + self.adam_opt.zero_grad() + + for batch in loader: + elapsed = time.time() - start_time + + # Rank-0 controlled deterministic cutoff (prevents NCCL hangs) + check_every = self.cfg.grad_accum * 5 + if global_step > 0 and global_step % check_every == 0: + stop = torch.tensor([0], dtype=torch.int, device=self.device) + if rank_is_primary(self.rank) and \ + elapsed >= self.cfg.time_limit_sec - 2.5: + print(f"[Cutoff] Safe stop at {elapsed:.2f}s (step={global_step})") + stop[0] = 1 + if dist.is_initialized(): + dist.broadcast(stop, src=0) + if stop.item() == 1: + break + + if global_step >= self.cfg.max_steps: + if rank_is_primary(self.rank): + print(f"Max steps reached: {global_step} ({elapsed:.0f}s)") + break + + # Learning rate and momentum scheduling + lr_now = cosine_lr_schedule(global_step, self.cfg) + ratio = lr_now / self.cfg.lr_matrix + progress = min(1.0, global_step / self.cfg.max_steps) + mom_now = 0.95 + 0.04 * 0.5 * (1 + math.cos(math.pi * progress)) + + for g in self.muon_opt.param_groups: + g["lr"] = lr_now + g["momentum"] = mom_now + for g in self.adam_opt.param_groups: + if "initial_lr" not in g: + g["initial_lr"] = g["lr"] + g["lr"] = g["initial_lr"] * ratio + + token_ids = batch["input_ids"].to(self.device, non_blocking=True) + targets = batch["labels"].to(self.device, non_blocking=True) + + accumulating = (global_step + 1) % self.cfg.grad_accum != 0 + if accumulating and isinstance(self.model, DDP): + with self.model.no_sync(): + with torch.autocast(device_type="cuda", dtype=self.amp_dtype): + loss, _ = self.model(token_ids, targets) + loss = loss / self.cfg.grad_accum + self.scaler.scale(loss).backward() + else: + with torch.autocast(device_type="cuda", dtype=self.amp_dtype): + loss, _ = self.model(token_ids, targets) + loss = loss / self.cfg.grad_accum + self.scaler.scale(loss).backward() + + if (global_step + 1) % self.cfg.grad_accum == 0: + self.scaler.unscale_(self.muon_opt) + self.scaler.unscale_(self.adam_opt) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), + self.cfg.grad_clip) + self.scaler.step(self.muon_opt) + self.scaler.step(self.adam_opt) + self.scaler.update() + self.muon_opt.zero_grad() + self.adam_opt.zero_grad() + raw_net = self.model.module if isinstance(self.model, DDP) else self.model + if hasattr(self.ema_shadow, "update_parameters"): + self.ema_shadow.update_parameters(raw_net) + + global_step += 1 + if rank_is_primary(self.rank) and global_step % 10 == 0: + bpb_est = loss.item() * self.cfg.grad_accum / math.log(2) + print(f" step {global_step:5d} | BPB~{bpb_est:.4f} | " + f"LR {lr_now:.3e} | {elapsed:.0f}s") + self.logger.log(event="step", step=global_step, + bpb=round(bpb_est, 6), lr=round(lr_now, 8), + elapsed_s=round(elapsed, 1)) + + return global_step + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# EVALUATION — Score-First TTT +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +def compute_bpb(model: nn.Module, tokenizer, texts: list, + device, cfg: GolfConfig): + """Sliding-window BPB evaluation with legal Score-First TTT. + + For each sliding window chunk: + 1. SCORE: no_grad forward on current parameters → record NLL. + 2. ADAPT: separate forward + backward + SGD step to improve future chunks. + + This ordering guarantees no look-ahead contamination, matching competition + evaluation semantics. + + Args: + model: GolfTransformer (eval mode set internally). + tokenizer: PreTrainedTokenizerFast. + texts: List of evaluation document strings. + device: CUDA device. + cfg: GolfConfig with use_ttt, ttt_lr, max_seq_len, eval_stride. + + Returns: + Tuple of (bits_per_byte, total_tokens_scored). + """ + use_ttt = cfg.use_ttt + use_bf16 = torch.cuda.is_bf16_supported() + amp_dtype = torch.bfloat16 if use_bf16 else torch.float16 + + model.eval() + + ttt_opt = None + if use_ttt: + ttt_opt = torch.optim.SGD(model.parameters(), lr=cfg.ttt_lr) + + seq_len = cfg.max_seq_len + stride = cfg.eval_stride + total_nll = 0.0 + total_tokens = 0 + total_bytes = 0 + + for text in texts: + seq_ids = tokenizer.encode(text, add_special_tokens=True) + if len(seq_ids) < 2: + continue + total_bytes += len(text.encode("utf-8")) + ids_tensor = torch.tensor(seq_ids, dtype=torch.long, device=device) + L = len(ids_tensor) + prev_end = 0 + + for start in range(0, L - 1, stride): + end = min(start + seq_len + 1, L) + chunk = ids_tensor[start:end] + if len(chunk) < 2: + break + + inp = chunk[:-1].unsqueeze(0) + tgt = chunk[1:].unsqueeze(0) + n_new = min(stride, tgt.shape[1] - max(0, prev_end - start - 1)) + if n_new <= 0: + prev_end = end + continue + + # ── SCORE phase: record NLL on current parameters ────────────── + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=amp_dtype): + _, score_logits = model(inp, labels=None) + + nll_sum = F.cross_entropy( + score_logits[0, -n_new:], + tgt[0, -n_new:], + reduction="sum", + ) + total_nll += nll_sum.item() + total_tokens += n_new + prev_end = end + + # ── ADAPT phase: SGD step for future chunks ──────────────────── + if use_ttt and ttt_opt is not None: + ttt_opt.zero_grad() + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=amp_dtype): + _, adapt_logits = model(inp, labels=None) + adapt_loss = F.cross_entropy(adapt_logits[0], tgt[0]) + adapt_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + ttt_opt.step() + + if total_bytes == 0 or total_tokens == 0: + return float("inf"), 0 + + avg_tok_per_byte = total_tokens / max(total_bytes, 1) + bpb = (total_nll / total_tokens) / math.log(2) * avg_tok_per_byte + return bpb, total_tokens + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# EXPORT — Dynamic MSE SDClip + INT6 Error Diffusion + LZMA +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# LQER asymmetric rank-4 (post-quant correction) + Brotli compression +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +def _lqer_pack_asym(A, B, g=64): + """Pack A as INT2 (single fp16 scale), B as INT4 per-group-g.""" + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to(torch.int8).reshape(B.shape) + return qA, sA, qB, sB + +def _lqer_unpack_asym(qA, sA, qB, sB): + A = qA.float() * float(sA) + g = qB.numel() // sB.numel() + Bf = qB.reshape(-1, g).float() * sB.float().view(-1, 1) + return A @ Bf.reshape(qB.shape) + +_BSHF_MAGIC = b"BSHF" +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src_arr = np.frombuffer(data, dtype=np.uint8) + n = len(src_arr); out = np.empty(n, dtype=np.uint8); off = 0 + for pos in range(stride): + chunk = src_arr[pos::stride] + out[off:off+len(chunk)] = chunk + off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload); out = np.empty(n, dtype=np.uint8); off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[off:off+chunk_len] + off += chunk_len + return out.tobytes() + + +def export_submission(model, ema_shadow, cfg: GolfConfig, + rank: int, *_args) -> Optional[tuple]: + """Quantise and compress the model checkpoint for submission. + + Quantisation pipeline per weight matrix: + 1. Dynamic SDClip: test sigma ∈ {2.5, 3.0, 3.5, 4.0}, keep the value + that minimises INT6 round-trip reconstruction MSE. + 2. Vectorised Sigma-Delta (error diffusion) quantisation to int8 codes. + 3. Per-row fp16 scale stored alongside codes. + EMA weights are applied before quantisation when available. + + Args: + model: DDP-wrapped or raw GolfTransformer. + ema_shadow: AveragedModel EMA copy. + cfg: GolfConfig with output_path, sdclip_sigma, sdclip_dynamic, etc. + rank: Process rank (only rank 0 saves). + + Returns: + (lzma_size_mb, fp16_size_mb) or None if not rank 0. + """ + if not rank_is_primary(rank): + return None + + net = model.module if isinstance(model, DDP) else model + if hasattr(ema_shadow, "module"): + ema_sd = ema_shadow.module.state_dict() + # torch.compile wraps the model as OptimizedModule whose state-dict + # keys carry a "_orig_mod." prefix; remap EMA keys to match. + net_keys = set(net.state_dict().keys()) + if net_keys and next(iter(net_keys)).startswith("_orig_mod."): + ema_sd = {"_orig_mod." + k: v for k, v in ema_sd.items()} + net.load_state_dict(ema_sd) + print("EMA weights applied.") + + # Sigma grid for per-tensor MSE search + sigma_candidates = ([2.5, 3.0, 3.5, 4.0] if cfg.sdclip_dynamic + else [cfg.sdclip_sigma]) + + quantised_state: Dict[str, torch.Tensor] = {} + + for key, tensor in net.state_dict().items(): + clean_key = key.replace("_orig_mod.", "") + fp_tensor = tensor.float().cpu() + + if fp_tensor.dim() >= 2: + limit = (127.0 if ("embed" in clean_key or "head" in clean_key + or "pos" in clean_key) else 31.0) + row_std = fp_tensor.std(dim=-1, keepdim=True).clamp(min=1e-6) + + best_mse = float("inf") + best_codes = None + best_scale = None + + for sigma in sigma_candidates: + clip_bound = row_std * sigma + clipped = fp_tensor.clamp(-clip_bound, clip_bound) + scale = (clipped.abs().max(dim=-1, keepdim=True)[0] + / limit).clamp(min=1e-6) + + # Sigma-Delta error diffusion quantisation + csum = torch.cumsum(clipped, dim=1) + csum_q = (csum / scale).round() + codes = torch.empty_like(clipped, dtype=torch.int8) + codes[:, 0] = csum_q[:, 0].clamp(-limit, limit).to(torch.int8) + codes[:, 1:] = (csum_q[:, 1:] - csum_q[:, :-1]).clamp( + -limit, limit).to(torch.int8) + + # Reconstruction MSE — select sigma that minimises round-trip error + recon = codes.float() * scale + mse = F.mse_loss(recon, fp_tensor).item() + + if not math.isnan(mse) and mse < best_mse: + best_mse = mse + best_codes = codes + best_scale = scale + + if best_scale is None: + # Fallback if tensors are NaN + best_codes = torch.zeros_like(fp_tensor, dtype=torch.int8) + best_scale = torch.ones((fp_tensor.shape[0], 1), device=fp_tensor.device, dtype=fp_tensor.dtype) + + quantised_state[clean_key] = best_codes + quantised_state[clean_key + "_scale"] = best_scale.half() + else: + quantised_state[clean_key] = fp_tensor.half() + + # LQER asymmetric rank-4 post-quantization correction + if getattr(cfg, "lqer_enabled", False): + cands = [] + for k, codes in list(quantised_state.items()): + if k.endswith("_scale") or k.startswith("_lq"): + continue + scale_k = k + "_scale" + if scale_k not in quantised_state or codes.dim() < 2: + continue + scale = quantised_state[scale_k] + W_q = (codes.float() * scale.float()).reshape(codes.shape) + # Need original fp tensor — re-fetch from net.state_dict() + fp_key = "_orig_mod." + k if "_orig_mod." + k in net.state_dict() else k + if fp_key not in net.state_dict(): + continue + W_fp = net.state_dict()[fp_key].float().cpu() + if W_fp.shape != W_q.shape: + continue + E = W_fp - W_q + cands.append((k, E, float(E.norm()))) + cands.sort(key=lambda x: -x[2]) + for (name, E, _) in cands[:cfg.lqer_top_k]: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(cfg.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + Bm = Vh[:r, :].contiguous() + if Bm.numel() % cfg.lqer_asym_group == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, Bm, cfg.lqer_asym_group) + quantised_state[name + "_lqA"] = qA + quantised_state[name + "_lqAs"] = sA + quantised_state[name + "_lqB"] = qB + quantised_state[name + "_lqBs"] = sB + print(f" LQER+ {name}: rank={r} fro={E.norm():.3f}") + + payload = {"config": cfg.to_dict(), "state_dict": quantised_state} + torch.save(payload, cfg.output_path) + + if getattr(cfg, "use_brotli", False): + import brotli + with open(cfg.output_path, "rb") as fin: raw_bytes = fin.read() + shuffled = _byte_shuffle(raw_bytes, getattr(cfg, "byte_shuffle_stride", 2)) + compressed = brotli.compress(shuffled, quality=getattr(cfg, "brotli_quality", 11)) + lzma_path = cfg.output_path.replace(".pt", ".pt.lzma") + with open(lzma_path, "wb") as fout: + fout.write(compressed) + else: + lzma_path = cfg.output_path.replace(".pt", ".pt.lzma") + with open(cfg.output_path, "rb") as fin: raw_bytes = fin.read() + lzma_filters = [{"id": lzma.FILTER_LZMA2, "preset": 9, "dict_size": 1 << 27}] + with lzma.open(lzma_path, "wb", format=lzma.FORMAT_XZ, filters=lzma_filters) as fout: + fout.write(raw_bytes) + + lzma_mb = os.path.getsize(lzma_path) / 1024 ** 2 + fp16_mb = os.path.getsize(cfg.output_path) / 1024 ** 2 + budget_ok = "✅ within budget" if lzma_mb <= 16 else "❌ OVER BUDGET" + print(f"\n → {cfg.output_path} ({fp16_mb:.2f} MB fp16)") + print(f" → {lzma_path} ({lzma_mb:.2f} MB) [{budget_ok}]") + + with open(cfg.output_pkl_path, "wb") as fh: + pickle.dump(payload, fh, protocol=pickle.HIGHEST_PROTOCOL) + + return lzma_mb, fp16_mb + + +def load_checkpoint(path: str, device: str = "cpu") -> "GolfTransformer": + """Restore a GolfTransformer from a quantised checkpoint. + + Args: + path: Path to .pt or .pt.lzma checkpoint file. + device: Target device string. + + Returns: + GolfTransformer in eval mode with dequantised fp16 weights. + """ + if path.endswith(".lzma"): + # Try Brotli+shuffle first (our format), fall back to LZMA + with open(path, "rb") as fh: blob = fh.read() + try: + import brotli + raw = brotli.decompress(blob) + raw = _byte_unshuffle(raw) + data = torch.load(io.BytesIO(raw), map_location=device, weights_only=False) + except Exception: + with lzma.open(path, "rb") as fh: + data = torch.load(fh, map_location=device, weights_only=False) + else: + data = torch.load(path, map_location=device, weights_only=False) + + saved_state = data["state_dict"] + fp16_state = {} + for k, v in saved_state.items(): + if k.endswith("_scale") or k.endswith("_lqA") or k.endswith("_lqAs") or k.endswith("_lqB") or k.endswith("_lqBs"): + continue + scale_key = k + "_scale" + if scale_key in saved_state: + W_dq = (v.float() * saved_state[scale_key].float()) + # Apply LQER correction if present + if k + "_lqA" in saved_state: + qA = saved_state[k + "_lqA"]; sA = saved_state[k + "_lqAs"] + qB = saved_state[k + "_lqB"]; sB = saved_state[k + "_lqBs"] + W_dq = W_dq + _lqer_unpack_asym(qA, sA, qB, sB) + fp16_state[k] = W_dq.half() + else: + fp16_state[k] = v.half() if v.dtype == torch.float32 else v + + saved_cfg = data["config"] + # Reconstruct dataclass from saved dict + cfg_obj = GolfConfig(**{k: v for k, v in saved_cfg.items() + if k in GolfConfig.__dataclass_fields__}) + net = GolfTransformer(cfg_obj) + net.load_state_dict(fp16_state) + return net.eval() + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# ENTRY POINT +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +if __name__ == "__main__": + + local_rank, world_size, device = init_distributed() + + CFG = GolfConfig() + LOGGER = ExperimentLogger(enabled=rank_is_primary(local_rank), cfg=CFG) + + # ── Auto-downsize hidden_size to fit 16 MB LZMA budget ──────────────────── + LZMA_RATIO = 0.3668 + BUDGET_MB = 15.98 + size_candidates = [(512, 1024), (480, 960), (448, 896), (416, 832), (384, 768), (352, 704), (344, 688), (336, 672), (320, 640), (288, 576), (256, 512)] + for h, inter in size_candidates: + CFG["hidden_size"] = h + CFG["intermediate"] = inter + probe = GolfTransformer(CFG) + est_lzma = probe.fp16_size_mb() * LZMA_RATIO + if rank_is_primary(local_rank): + print(f"[SizeCheck] hidden={h} moe={inter} → est LZMA≈{est_lzma:.2f} MB") + if est_lzma <= BUDGET_MB: + break + + if rank_is_primary(local_rank): + print("=" * 65) + print(" DEEP MACARON — Competition Run — H100 8×") + print(" Macaron layer structure (FFN -> Attn -> FFN)") + print(" MQA Attention, SwiGLU, SiLU") + print(" WSD (Warmup-Stable-Decay) Learning Rate Schedule") + print(f" vocab={CFG.vocab_size} hidden={CFG.hidden_size} layers={CFG.num_layers}") + print("=" * 65) + + tokenizer = prepare_tokenizer(local_rank, CFG) + + t_data = time.time() + pipeline = build_data_stream(tokenizer, local_rank, world_size, CFG) + if rank_is_primary(local_rank): + print(f"Data pipeline ready in {time.time()-t_data:.2f}s") + + seeds_to_run = [1337, 42, 2025] + + for seed in seeds_to_run: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + CFG["output_path"] = f"./submission_global_seed{seed}.pt" + CFG["output_pkl_path"] = f"./submission_global_seed{seed}.pkl" + + if rank_is_primary(local_rank): + print(f"\n{'#'*65}") + print(f" SEED {seed} — starting training run") + print(f"{'#'*65}\n") + + base_net = GolfTransformer(CFG).to(device) + ema_net = make_ema_shadow(base_net, CFG.ema_decay) + + if CFG.use_compile: + if rank_is_primary(local_rank): + print("Compiling model …") + try: + base_net = torch.compile(base_net, mode="reduce-overhead") + if rank_is_primary(local_rank): + print(" torch.compile ✅") + except Exception as exc: + if rank_is_primary(local_rank): + print(f" torch.compile skipped: {exc}") + + if rank_is_primary(local_rank): + print("Calibrating step time (10 synthetic steps) …") + step_time = benchmark_step_time(base_net, device, CFG, n_warmup=10) + + if dist.is_initialized(): + st_buf = torch.tensor([step_time], dtype=torch.float32, device=device) + dist.broadcast(st_buf, src=0) + step_time = st_buf.item() + + if rank_is_primary(local_rank): + print(f" raw ~{step_time:.3f}s/step → ×1.02 = {step_time*1.02:.3f}s") + + configure_schedule(CFG, world_size, step_time) + + ddp_net = (DDP(base_net, device_ids=[0]) + if dist.is_initialized() else base_net) + + muon_opt, adam_opt = build_optimizer_pair( + ddp_net.module if isinstance(ddp_net, DDP) else ddp_net, CFG + ) + + trainer = GolfTrainer( + model=ddp_net, + ema_shadow=ema_net, + data_stream=pipeline, + muon_opt=muon_opt, + adam_opt=adam_opt, + device=device, + rank=local_rank, + world_size=world_size, + cfg=CFG, + logger=LOGGER, + ) + + t_train = time.time() + n_steps = trainer.fit() + train_time = time.time() - t_train + + if rank_is_primary(local_rank): + print(f"Training complete: {train_time:.2f}s") + LOGGER.log(event="train_end", seed=seed, steps=n_steps, + train_time_s=round(train_time, 2)) + + result = export_submission(ddp_net, ema_net, CFG, local_rank) + + if rank_is_primary(local_rank) and result: + lzma_mb, fp16_mb = result + LOGGER.log(event="checkpoint_saved", seed=seed, + path=CFG.output_path, + size_mb_fp16=round(fp16_mb, 4), + size_mb_lzma=round(lzma_mb, 4), + within_budget=lzma_mb <= 16.0) + + if rank_is_primary(local_rank) and result: + lzma_path = CFG.output_path.replace(".pt", ".pt.lzma") + eval_src = lzma_path if os.path.exists(lzma_path) else CFG.output_path + + print(f"\nBPB evaluation — stride={CFG.eval_stride} sliding window …") + t_eval = time.time() + eval_net = load_checkpoint(eval_src, device=str(device)) + eval_net.to(device) + + fw_texts = [x["text"] for x in load_dataset( + "HuggingFaceFW/fineweb", "sample-10BT", + split="train", streaming=True).take(CFG.eval_docs)] + bpb_val, n_scored = compute_bpb(eval_net, tokenizer, fw_texts, + device=device, cfg=CFG) + eval_time = time.time() - t_eval + + print(f"\n{'='*65}") + print(f" RUN COMPLETE (seed={seed})") + print(f" steps : {n_steps}") + print(f" size fp16 : {fp16_mb:.2f} MB") + print(f" size lzma : {lzma_mb:.2f} MB") + print(f" FineWeb BPB : {bpb_val:.4f} (stride={CFG.eval_stride})") + print(f" train time : {train_time:.2f}s " + f"({int(train_time//60)}m {int(train_time%60)}s)") + print(f" eval time : {eval_time:.1f}s") + print(f"{'='*65}") + + bytes_total = os.path.getsize(CFG.output_path) + bytes_code = os.path.getsize(__file__) + run_date = datetime.datetime.now( + datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + sub_data = { + "author": "Karen042009", + "github_id": "Karen042009", + "name": (f"Karen042009 - DualHashSkip AdaMuon " + f"DynSDClip SharedMoE LayerScale (Seed {seed})"), + "blurb": (f"Seed {seed}. DualTokenHashSkip(2×2048×16,concat), " + "LayerScale Recurrence (1.0/0.1 init), " + "AdaMuonOptimizer (RMS pre-cond + Riemannian NS), " + "Dynamic MSE SDClip (σ-grid {2.5,3.0,3.5,4.0} at export), " + "SharedMoE (1 shared + top-1/3 specialised), " + "Score-First TTT 2-pass, parallel residuals, " + "QKGain=5.25, PartialRoPE=32, " + "LeakyReLU(0.5)^2 SwiGLU, vocab=8192"), + "date": run_date, + "val_loss": round(float(bpb_val * math.log(2)), 6), + "val_bpb": round(bpb_val, 6), + "bytes_total": bytes_total, + "bytes_code": bytes_code, + } + + print(f"final_int8_zlib_roundtrip val_loss:{bpb_val*math.log(2):.6f} " + f"val_bpb:{bpb_val:.6f} eval_time:{eval_time*1000:.0f}ms") + + json_path = f"submission_global_seed{seed}.json" + with open(json_path, "w") as fh: + json.dump(sub_data, fh, indent=4) + print(f" Metadata → '{json_path}'") + + print("\n" + "=" * 65) + print(f" FINAL RESULT (Karen042009 SEED {seed})") + print(f" val_bpb: {bpb_val:.8f}") + print(f" val_loss: {bpb_val*math.log(2):.8f}") + print(f" size_lzma: {lzma_mb:.2f} MB (limit 16 MB)") + print(f" steps: {n_steps}") + print(f" train_time: {train_time:.2f}s " + f"({int(train_time//60)}m {int(train_time%60)}s)") + print("=" * 65) + + LOGGER.log( + event="seed_done", + seed=seed, + steps=n_steps, + size_mb_fp16=round(fp16_mb, 4), + size_mb_lzma=round(lzma_mb, 4), + fineweb_bpb=round(bpb_val, 8), + fineweb_val_loss=round(float(bpb_val * math.log(2)), 8), + eval_time_s=round(eval_time, 2), + scored_tokens=n_scored, + bytes_total=bytes_total, + bytes_code=bytes_code, + json_out=json_path, + checkpoint=CFG.output_path, + date=run_date, + ) + + eval_net.cpu() + del eval_net + + if "base_net" in dir(): + del base_net, ema_net, ddp_net, muon_opt, adam_opt, trainer + + import gc + gc.collect() + torch.cuda.empty_cache() + + if rank_is_primary(local_rank): + LOGGER.close(message="All seeds completed.") + + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group()