diff --git a/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/README.md b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/README.md new file mode 100644 index 0000000000..c76e0a317c --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/README.md @@ -0,0 +1,157 @@ +# SP8192 + 4-Layer Depth Recurrence + Parallel Residuals + QK-Gain 5.25 + Legal TTT + +**val_bpb = pending** (results to be collected on 8xA100s) | **~16MB** | 8xA100 + +## Summary + +Extends the current SOTA (PR #1509, val_bpb=1.0810) by widening the depth recurrence from 3 looped layers to 4 (`LOOP_END=6`). All other hyperparameters and techniques are carried forward unchanged. + +The only code change from SOTA: + +```python +# Before (SOTA, PR #1509) +loop_end=int(os.environ.get('LOOP_END', 5)) +qk_gain_init=float(os.environ.get('QK_GAIN_INIT', 5.)) + +# After (this PR) +loop_end=int(os.environ.get('LOOP_END', 6)) +qk_gain_init=float(os.environ.get('QK_GAIN_INIT', 5.25)) # absorbs the known-good default +``` + +## Architecture + +### Virtual Layer Sequence + +**SOTA (loop_end=5)** — 17 virtual layers, 8 U-Net skips: +``` +Encoder [0,1,2,3,4,5,3,4] → Decoder [5,3,4,5,6,7,8,9,10] + ───────────────── ────────────────────────── + pre └──loop──┘ 2nd 2nd └──loop──┘ post (par) +``` + +**This PR (loop_end=6)** — 19 virtual layers, 9 U-Net skips: +``` +Encoder [0,1,2,3,4,5,6,3,4] → Decoder [5,6,3,4,5,6,7,8,9,10] + ───────────────────── ──────────────────────────── + pre └────loop────┘ 2nd 2nd └────loop────┘ post (par) +``` + +Layer 6 is promoted from the non-recurring post-loop section into the recurrence core. It now executes 3 times (like layers 3, 4, 5) instead of once. 
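
The encoder/decoder split above can be derived mechanically from `loop_start`, `loop_end`, and `num_loops`. A minimal sketch of the construction used in `test_architecture.py` (the helper name `virtual_sequence` is illustrative, not from the script):

```python
def virtual_sequence(num_layers=11, loop_start=3, loop_end=6, num_loops=2):
    """Flattened layer-index sequence: pre-loop layers once, the looped
    segment (num_loops + 1) times, then post-loop layers once. The first
    half is the encoder path, the second half the decoder path."""
    loop_seg = list(range(loop_start, loop_end + 1))
    seq = list(range(loop_start))
    for _ in range(num_loops + 1):
        seq += loop_seg
    seq += list(range(loop_end + 1, num_layers))
    mid = len(seq) // 2
    return seq[:mid], seq[mid:]

enc, dec = virtual_sequence(loop_end=6)
print(enc)  # [0, 1, 2, 3, 4, 5, 6, 3, 4]
print(dec)  # [5, 6, 3, 4, 5, 6, 7, 8, 9, 10]
```

Calling it with `loop_end=5` reproduces the SOTA sequences shown above.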
### U-Net Skip Connections

With 9 encoder and 10 decoder steps, there are 9 skip connections. Skips are consumed LIFO, so encoder step i feeds the decoder step at position 8-i; the final decoder step (L10) receives no skip:

| Skip | Encoder step | Decoder step |
|------|-------------|-------------|
| 0 | L0 (1st pass) | L9 (parallel) |
| 1 | L1 (1st pass) | L8 (parallel) |
| 2 | L2 (1st pass) | L7 (parallel) |
| 3 | L3 (1st pass) | L6 (3rd loop pass) |
| 4 | L4 (1st pass) | L5 (3rd loop pass) |
| 5 | L5 (1st pass) | L4 (3rd loop pass) |
| 6 | **L6 (1st pass)** | **L3 (3rd loop pass)** |
| 7 | L3 (2nd pass) | L6 (2nd loop pass) |
| 8 | L4 (2nd pass) | L5 (2nd loop pass) |

Skip 6 is new in this PR. It connects the first pass through L6 (shallow context) to the third pass through L3 (deep re-processing), providing a residual shortcut that was absent in the 3-layer config.

### Parallel Residuals

Unchanged: layers 7, 8, 9, 10 use GPT-J-style parallel attention+MLP. These all appear in the decoder's post-loop section and are not part of the recurrence.

## Motivation: Compute Budget Equivalence

The 4-layer loop is slower per step (19 vs 17 virtual forward passes), but the total layer-step budget is identical:

| Config | Virtual layers | Est. steps | Layer-steps |
|--------|---------------|-----------|-------------|
| SOTA | 17 | ~4,550 | ~77,350 |
| This PR | 19 | ~4,071 | ~77,349 |

The prior depth-recurrence progression shows monotonic improvement with more virtual depth:

| Submission | Looped layers | Virtual layers | val_bpb (no TTT) |
|-----------|--------------|---------------|-----------------|
| PR #1260 (2-layer loop) | [4,5] | ~13 | 1.0979 |
| PR #1394 (3-layer loop) | [3,4,5] | 17 | 1.0856 |
| This PR (4-layer loop) | [3,4,5,6] | 19 | pending |

Each expansion has improved BPB without increasing the compute budget. The working hypothesis is that, at a fixed layer-step budget, extra depth per forward pass is more sample-efficient than extra optimizer steps at the same depth. 
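
The equal-budget step count follows directly from holding layer-steps fixed and solving for steps. A quick arithmetic check of the budget numbers above:

```python
# Layer-steps = virtual layers per optimizer step * number of steps.
sota_layer_steps = 17 * 4550       # SOTA budget: 77,350 layer-steps
steps_4layer = sota_layer_steps // 19  # equal-budget steps at 19 virtual layers
print(steps_4layer)                # 4071
print(19 * steps_4layer)           # 77349 (one layer-step short, from flooring)
```

The one-layer-step discrepancy in the table is just integer rounding.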
+ +## Local Verification + +`test_architecture.py` in this directory validates the 4-layer config on CPU (no CUDA, no flash_attn needed): + +``` +$ python test_architecture.py +============================================================ +Config: SOTA (loop_end=5) + encoder: [0, 1, 2, 3, 4, 5, 3, 4] + decoder: [5, 3, 4, 5, 6, 7, 8, 9, 10] + virtual_layers=17, skips=8 + forward pass: OK shape=torch.Size([2, 32, 256]) + gradient flow: OK loss=5.5465 + all looped blocks have clean gradients: OK + +============================================================ +Config: This PR (loop_end=6) + encoder: [0, 1, 2, 3, 4, 5, 6, 3, 4] + decoder: [5, 6, 3, 4, 5, 6, 7, 8, 9, 10] + virtual_layers=19, skips=9 + forward pass: OK shape=torch.Size([2, 32, 256]) + gradient flow: OK loss=5.5435 + all looped blocks have clean gradients: OK + +All architecture tests PASSED. +``` + +## Full Technique Stack (carried from SOTA) + +1. **SP8192** tokenizer (kevclark/parameter-golf HuggingFace dataset) +2. **4-Layer Depth Recurrence** — layers [3,4,5,6], 3 total passes, activated at `frac=0.35` +3. **Parallel Residuals** — from layer 7, GPT-J style (attention and MLP read same input) +4. **QK-Gain 5.25** — learnable per-head query scaling (now the script default) +5. **MuonEq-R** — row-normalized Muon with Newton-Schulz 5 steps +6. **Legal Score-First TTT** — SGD lr=0.005, momentum=0.9, 3 epochs/32K chunk, cosine LR decay +7. **Full-Hessian GPTQ SDClip** — int6 matrices (k=12.85), int8 embeddings (k=20.0) +8. **Byte-shuffle + Brotli-11** compression +9. 
**LZMA code wrapper** — ~16.6KB self-extracting code + +## Reproduction + +```bash +pip install brotli sentencepiece +pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/ +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192 + +SEED=42 TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 \ + torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +No additional env vars needed. `QK_GAIN_INIT=5.25` and `LOOP_END=6` are now script defaults. + +## Expected Behavior + +- Training: ~4071 steps in ~588s (vs 4550 for SOTA), 11.8% more compute per step +- Artifact: ~15.99 MB (essentially unchanged — one extra skip_weight/skip_gate row adds ~4KB uncompressed) +- Eval: identical sliding window + TTT pipeline, same timing budget + +## Compliance + +Identical to SOTA (PR #1509). Score-first TTT, no SLOT, no n-gram cache, no pre-quant TTT, no ETLB. + +## Files + +- `train_gpt.py` — LZMA-compressed production script (2-line diff from SOTA) +- `train_gpt_human.py` — human-readable version of the same code +- `test_architecture.py` — CPU smoke test, no dependencies beyond PyTorch + +## Credits + +- **@clarkkev** — SP8192 + GPTQ + SDClip + MuonEq-R + depth recurrence base (PR #1394) +- **@dexhunter** — 3-layer depth recurrence (PR #1331, #1437), legal TTT on SP8192 (PR #1413) +- **@Robby955** — Parallel residuals on SP8192 (PR #1412) +- **@msisovic** — Parallel residuals concept (PR #1204) +- **@X-Abhishek-X** — Hyperparameter tuning: WD=0.095, MLR=0.022, EMA=0.9965 (PR #1445, #1471) +- **@abaybektursun** — Score-first TTT framework (PR #549) diff --git a/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/submission.json b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/submission.json new file mode 100644 index 0000000000..e963db1757 --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/submission.json @@ -0,0 +1,7 @@ +{ + 
"name": "SP8192 + 4-Layer Depth Recurrence + Parallel Residuals + QK-Gain 5.25 + Legal TTT", + "github_id": "tashapais", + "val_bpb": null, + "date": "2026-04-16", + "notes": "Extends SOTA 3-layer depth recurrence to 4 layers (loop_end=6). Results pending on 8xA100s." +} diff --git a/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/test_architecture.py b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/test_architecture.py new file mode 100644 index 0000000000..989b00919e --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/test_architecture.py @@ -0,0 +1,183 @@ +""" +Smoke-test for the 4-layer depth recurrence architecture. + +Runs on CPU with standard PyTorch (no flash_attn, no CUDA, no data). +Verifies that: + 1. Virtual layer sequences match the expected encoder/decoder paths + 2. Forward pass produces correct output shapes + 3. Gradients flow cleanly through all blocks (including looped blocks) + 4. SOTA (loop_end=5) and this submission (loop_end=6) both pass + +Usage: + python test_architecture.py +""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class RMSNorm(nn.Module): + def forward(self, x): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def forward(self, x): + return F.linear(x, self.weight.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + + def forward(self, x): + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Attention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, 
kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.q_gain = nn.Parameter(torch.ones(num_heads)) + + def forward(self, x): + B, T, D = x.shape + hd = self.head_dim + q = F.rms_norm(self.c_q(x).reshape(B, T, self.num_heads, hd), (hd,)) + k = F.rms_norm(self.c_k(x).reshape(B, T, self.num_kv_heads, hd), (hd,)) + v = self.c_v(x).reshape(B, T, self.num_kv_heads, hd) + q = q * self.q_gain[None, None, :, None] + grp = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(grp, dim=2) + v = v.repeat_interleave(grp, dim=2) + y = F.scaled_dot_product_attention( + q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), + is_causal=True, + ).permute(0, 2, 1, 3).reshape(B, T, D) + return self.proj(y) + + +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, layer_idx): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = Attention(dim, num_heads, num_kv_heads) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim)) + self.mlp_scale = nn.Parameter(torch.ones(dim)) + self.resid_mix = nn.Parameter(torch.stack([torch.ones(dim), torch.zeros(dim)]).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) + self.parallel = False + + def forward(self, x, x0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0] * x + mix[1] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor) + if self.parallel: + mlp_out = self.mlp(self.mlp_norm(x_in) * self.ln_scale_factor) + return x_in + self.attn_scale.to(x_in.dtype) * attn_out + self.mlp_scale.to(x_in.dtype) * mlp_out + x2 = x_in + self.attn_scale.to(x_in.dtype) * attn_out + return x2 + self.mlp_scale.to(x2.dtype) * self.mlp(self.mlp_norm(x2) * self.ln_scale_factor) + + +class GPTRecurrent(nn.Module): + def __init__(self, vocab_size=256, num_layers=11, dim=64, + num_heads=8, num_kv_heads=4, mlp_mult=4, + loop_start=3, 
loop_end=5, num_loops=2, + parallel_residual_start=7): + super().__init__() + self.blocks = nn.ModuleList([ + Block(dim, num_heads, num_kv_heads, mlp_mult, i) + for i in range(num_layers) + ]) + for i in range(num_layers): + self.blocks[i].parallel = (i >= parallel_residual_start) + + loop_seg = list(range(loop_start, loop_end + 1)) + all_idx = list(range(loop_start)) + for _ in range(num_loops + 1): + all_idx.extend(loop_seg) + all_idx.extend(range(loop_end + 1, num_layers)) + + num_enc = len(all_idx) // 2 + self.encoder_indices = all_idx[:num_enc] + self.decoder_indices = all_idx[num_enc:] + self.num_skip = min(len(self.encoder_indices), len(self.decoder_indices)) + + self.skip_weights = nn.Parameter(torch.ones(self.num_skip, dim)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip, dim)) + self.tok_emb = nn.Embedding(vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + def forward(self, input_ids): + x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,)) + x0 = x + skips = [] + for i in self.encoder_indices: + x = self.blocks[i](x, x0) + skips.append(x) + for skip_idx, i in enumerate(self.decoder_indices): + if skip_idx < self.num_skip and skips: + sk = self.skip_weights[skip_idx].to(x.dtype) * skips.pop() + g = torch.sigmoid(self.skip_gates[skip_idx].to(x.dtype)) + x = torch.lerp(sk, x, g) + x = self.blocks[i](x, x0) + return F.linear(x, self.tok_emb.weight) + + +def run_test(name, loop_end, expected_virtual_layers, expected_skips): + print(f"\n{'='*60}") + print(f"Config: {name}") + + model = GPTRecurrent(loop_end=loop_end) + vl = len(model.encoder_indices) + len(model.decoder_indices) + assert vl == expected_virtual_layers, f"Expected {expected_virtual_layers} virtual layers, got {vl}" + assert model.num_skip == expected_skips, f"Expected {expected_skips} skips, got {model.num_skip}" + + print(f" encoder: {model.encoder_indices}") + print(f" decoder: {model.decoder_indices}") + print(f" virtual_layers={vl}, 
skips={model.num_skip}") + + # Forward pass + B, T = 2, 32 + ids = torch.randint(0, 256, (B, T)) + logits = model(ids) + assert logits.shape == (B, T, 256), f"Bad output shape: {logits.shape}" + assert not logits.isnan().any(), "NaN in forward pass" + print(f" forward pass: OK shape={logits.shape}") + + # Gradient flow + loss = F.cross_entropy(logits[:, :-1].reshape(-1, 256), ids[:, 1:].reshape(-1)) + loss.backward() + no_grad = [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None] + has_nan = [n for n, p in model.named_parameters() if p.grad is not None and p.grad.isnan().any()] + assert not no_grad, f"Missing gradients: {no_grad}" + assert not has_nan, f"NaN gradients: {has_nan}" + print(f" gradient flow: OK loss={loss.item():.4f}") + + # Verify blocks in looped section receive gradients (key check for recurrence) + looped = set(model.encoder_indices + model.decoder_indices) + for block_idx in range(11): + if block_idx in looped: + for p in model.blocks[block_idx].parameters(): + assert p.grad is not None and not p.grad.isnan().any() + print(f" all looped blocks have clean gradients: OK") + + +if __name__ == "__main__": + run_test("SOTA (loop_end=5)", loop_end=5, expected_virtual_layers=17, expected_skips=8) + run_test("This PR (loop_end=6)", loop_end=6, expected_virtual_layers=19, expected_skips=9) + print(f"\n{'='*60}") + print("All architecture tests PASSED.") diff --git a/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/train_gpt.py b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/train_gpt.py new file mode 100644 index 0000000000..10a56715e0 --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/train_gpt.py @@ -0,0 +1,2 @@ +import lzma as L,base64 as B 
+exec(L.decompress(B.b85decode(";JwH*dR+iCn@VT6Qap3bt~@<3h>ok~)Km^%c^ys%R{D_%yAk9-_tV7^coUOo3$w>`(`ci)t`2F7>r>Ltx>>S2CRw|7ov>Wn1e~_!RLQ=%V9g?)G3yPsu%SBy!lj1PaC-x%dDmCDOZ^r^!)+WWz}ejKXTJ#^U6Ra!};QocHHXQC+4UM!QQ!-N5Xd|%~a(9)bTYIO+>B~8~@lqmri%^qEkQUy074Rh6w7V_#^s9J-3BNA`G;qyR$LYcI?e+loZVWi~B$n=TKFp{%SeHYp{oNWh;U@Ahk8M2$OU%K8B$lb*dRQXd-GR_@*KAZdRdwSd#v=LSq1v@Puul=a7WXDmh1^kBj}Y2XlER!D2E{&{%lV(hz$#n5%+%sk&Q}>{y0xpRgiQQBJeVV0hy8UD3ntyo@(Pv+K7^zVRDt4bah(r8kfsZThb+H1)~K-lIr4`|V#-2R>G7pP*N!fwWd&Dq8C)y=NrG_U_Oz6Q?+@ok1?(VJ5?ZT~&}C4Ks38WRB>3i=I!}H-8qq=&yKJ;tbpwwn~lAseD^q1C*u5T;lKQtF;?zv@u0f36%6SXU~txi3v5iSPK*`fNE9531KaQDL`zTPF$MX4U(-3sY-&?>QJe)giBQzpor7H)AZ#4=Hn#`AoAL7tT){&bw(fgz|eQRt`#6-<>;m*+&$!nf|od6&lVKYYHuOoNgZU_L>E@!O%__mlt=);Hwdc43+CM?sh5y+my3XSVYMO8F1pXuq$fvTU<$mpDjr>J(TToE`V^_}Zn!5!L?z?WEjbgs4FEL|QI3d+uTVjjDP;5mS|KGdf(Q3`k2utZ{v36ih_0RnqB7`I-Uhbix!dTRAI5lSVD~$w`;0vXIn7%joh_Zwk3J&7PSm8$@)vXE{xay(xXF3nUp;r@UyCrfJF#}}^&+d$Vj*~UQwb}>mDndelaV#h4|oY84(e*bErzh%f4Dp487`*Wiu5!@0RVdy3yUNSgn=cLpDov{1D{d19U2|yTLINWhGsg6|Ly_y5T1HIB?MrIcor{1vjScKf!{mk!c7GxGF34g;ESg0m~DEDp(e$+qn5L9Nhgl=T`_L2LkJCk^xumD!_|tarei{(xV0Z~5N6V*qJ!C0i-9lXvMDs7az5dV^-FmV3>OG}WzT&vMJZOzAGqhiJ5r-N*|N+fgQNWEKpTgIHG9i{tX>>3_%tCbL6D;16%GQp%&~KTzW+$_H`ljJnTs9rd4^#mB8i+`AKyP6gv7Y7_lFr){1iT!tU%yioSZ_P)iUEb~WT?yww%#G)f;ve9a_S*aoJFM%tYL~0o}jOPbtBS7$Ci}>xa8r*Fn1;)&D$)*L`B!Sq7gEL0<#P(w6-X~qnVj+iF{AxRrG@B9x{cj*u9abZQ7n_6Nt*~DVgfyTt|omv-vwhpA;pdq$A{miGg^!$|xOym`PnJek0gA?{1xj7_n^8=6=^cN*Xdm155+V^&xwryhh$Srp@-JhmkFobWK?|C!GWE2dFQAFpElM^VAv>t9K_Rc8}#7$TI!XBbE5HCpT$t!R>yvFrY`C+4mV5c)S5;?<06?;x!S7zM1f7{P{lF)x1rid4@Dm>i-syl3`vo74pYa2qP%r*`r5*2*_8rz^oG>T=Mc;U4*4Q7{l+^-te<4Xw^~%pfO%4huhpWs^OO3;yvaot#cpojP@jbj3mLzG2k6SST-^o2C+z2_R0k2Xlym*m|NX+7*05*|BOE(gAkt7CfkE+M((zT4jP^S?_?otiv|Lzw7k9uYL@bp-v^wcbk?Xp4C^H4l5HeV{ikk$9qUE~6-)3ka;PkbF3^$JQkbQc0p+Y$cs-eTZ&;zTFjm3X(XkZKJ@uSBok)@%DYDBK{!L5V6H4$$>O#g~m43Z#na%}c%aYYC439ETVP#f9vy6>>PpUcxNm$A5Bk1@FQB+zK!T?1jVGM7slC%H&y5MPe45t3mPKE?|6`q3EVR$ISAz>{&+zns5yfvx0!BlIttF`~;q9d#AQSx8zc!=nZj_hy=hPT
^CX}BQJ#-i1YTt{0@CM0P?DiS)75d+W=Rbo=dL8{50Idq1jSuJ(t;9Yx>(5P*QE!S$|ThTf8Nm@_0hwq`!mWzfGw3~uWrpIePPY%aqEdgz}lRcmoL&K$5<1rBsvEs8H9}4qDzZd~>`noAELDNlK0oFze%v%z#s`f&VSgcLh$5k!0vhO@BAlI=spvpUp=EDU4bt;J?MLjom;*^$Q@%?rLen@J;wJ9@boq`UMPr%Ns|NnbZxNcyd$Gwvp#<&LeY{Rmi3P$2`X|9NcG!j1gyoiAlxj*d!6utV6Wv6wA=|uoNO__sFuOnk^Z-a|0Wo#A3yG&+p~y<_=IC_rldx$cqCNtBgoD677TcZ2--;J-_*&bOT%~#nFs^67~o(>6(U9@Ayc!S_RRVQdV5h*2~{R}0gpmvKd0LmWtRUVx8qG(~!Oj_U2^H=+SW!dWtXg^Xy0bCIDJoj4bHn=HG|*NS+6I#U_v`NZ#b$jIxsB6a5?|o|KJ)Wz>L-9C{d3Gufa8&F|9C|SR_>spSPsSwWqBZgBZJE8_XS?Fqym3ox>k0FR%VdwXo%05xCO9WitiNo1!%MjjsGloK-F}c~6=*-@^fd?nhlm~5^n?n&O2j%b^s#JZB*Y_$VH_M2c`jOpA-VB@4>9@)ieX#(^XdUD^P&8*z*3C#J-Z-TwRvx1dD!=K&`i#Y&X;P`4F-gbj;Y%CrTM4<9NGL2ev@Rz7j9M6$a6n&87{WnPsRhP<{zT{BQt80<^mFt3kAVl~D<$rdkOxZXIV)WRi(?~xM<1eka_=&OSOJBZBHi(pf^z0c`5@^&-+9V7vYRDxQ))IT`x;mKwMJhl*S0@S^F3qox57srLF3n>!uB~Ikz)0K(%ja(Y8^y213M4i#hGPWU)(gxxiDqPWz>{Y{%(l667o#>;Oykz}nMV9hvt8OB&%XS;L+H2)nog^_MuCfj&TtU5T!EKKR!H#wNMbwD(6VYZ3WQw1rujwajz``F^Y8m9DOvN2ITX|N(KHm0N&9zs)L`lR<+XW>MQR5+n?Lj;5Qf^LKte130Up(1=ykzKa|t-Qe9Aia?WA&w{26!M2d4lY@D!M<$*fP5Cu`*j|T^ccXXLPa?|;Ju$Aw4-D%y78X{s`Nh`i3&TCU=aDoM2aGlM2MWh+O1Pf&hG}u7@twVUL1*e&t))g|`QzHB0?AdEDz?BW#}d{uLkNWGBd$AaBN10~#Bo{uLgV;ss`t$(K6sF&TiB77z^Nz4Xn1E30|-QZJ#VvQJ0B-sEApbaLGuX8^U=*%wjRfT7#-V9kH&;KJRDpDNEHVN^KL|n){n3sf&PIO4B^L7WrRkwx@-5yF5#lm!3$a&#N0Hc>(@r|j=q-FulC@>kWGgS=Fh}Ze3whQ)mt=95vi_IeIEB2?Q@r|4&TFBM6MJoGKr%>61vP??!=fG)E4UQ)J;Qcrb9T|TJ&9o@_8q|4Qp*fj3B1XNFhw>`m}<%fP>Ei583hSzxF6bqOHUvN}B*ulB{AWaoSjGSdzUDWMt!3DT*N2J{Bt_fkFQRy0O=wie|w3w{7mo9bK)A1AhREM~=~(!4-;&TPzDJic9pMfATjuA}0pdz8=nc-g{@hx5jvQ7$_C+=uf1!s~orCPz&N{nNef)yQKAHuJemh5%9k=WduZC@J8zJe5?hI6DC*u6q8YPfsK7iF#cG)(j@C6PmzNy)ETXW<&bg}wC0PA*EPU-3VT8v)r?i{#JRFvtvJ?st#AY$-w#QI26eow7%QF=nOk+@A<*|6`tLY99ob?PX-X{TJT7sKJk5*@P)P{EDil8@7*TL~kj-~OPm3(!Tw!9yQy?aJ!!ID=J#Or}K5$yZsm2b|HJB8FN67X~gT7|;kNla=9Ix7jXX!0)j9X&u%*4fFk>l2faMzT-M%NjNckaE}!6Cq4N>To1x4pFK#yM3=fyRP#8eEQ4qU0IL@k8^j;qTdZXmAG)zQ;@+V`tP#J7u;=ZCW|SZrq~OT2QX{l5$P#H07r^<%4OzM+4u98Kn++!=RmblH-O5$kM>k9d#+&z63
DCZT%+$BV!TdHCI_#`*$+za0|njJZIsUQASy`cTbh^sNQ4~i__>tGrlok%e@-!r+_ee(y*K0G=-*Z1wBv3dbJ6P;?SfvgL$U1ma=_O%!88(zoM`Y9T83eEyu_}tV*kqJ>V@`LzP<1Ya#pf+M#=OXHqlNm?`Q+I!A_=5~BvYty+&k$nFp=f(g`1cmAXDpCPwgvn@TAbMblen$9+Wg?jXXS}fWEMDP2V5tWL@j`S`;ADGee!a&PU2So0c?$AVk8n*ST?Mf5kU!%T}=wM3n@)EM5x(#25iPuR=Nk&(De&sLr2pD6MFnIqAm#dadPHhaDVCRg6j1~1zi>hMw0^D9?B)>6V3w=;%w0m}o*Y@m94SEgo?->j;a75Sad&8TS6tUtzq&NZhDs+6xvF2pb-P0#V6{tCc6-QL3pop^vjTVz;+OYa=kk8=M(tAB1ht+N1^4^Xr<&MwnTSQSGvh#68MVd6%HCWr_sUMyO?uVlrsYwHJgMMO+l76{5S65Km##Pl0Vw+APLiv~C68Imr0Ktu+~Au=&gBM)z*%&$Q}+ggjr+JnsHZ~`iiZ`qz9)O57E|KQ`d2y*Bd)ORd3O(xBnsZh6RiNdc1Q;#5$BYEYjS@w#Ik^-_h_3<+~8?8$rYtC?G2-53H7fc4~1Q=~mZLQJoM3#2;c`+TL5GdcGbdvGeP#Z3?vS0O!{!QUSC3)OdD?En=65j*@{`Pj6Fy08UD5`H?`DE`d#CQ14%>^M~pECs!gs0&v_S+>OxP?y#)5aa~;trq`tSvuX%Gw}m6@=RN2v;W)yPGaS-A0=IKzOd*B43E|o3J?So3fqC?*nu-5J{rKu4%i*U-AFZdl6pQaE3yV>z{N@Jbx;>9$1+3Q0sl$XY;vpB*{fS?O%J9P(jXWyz>)o=N5=H_KZ#wUYXIuGf=Edj!qr-A-;Y4C64FZ8bgffwaxCYR$H6HuVok$MWU%+HJ5*|%483~Qa4fMaP4(oI;Km-ZOYl4MsO?jZ46#8OiwIpuhJ?4r2I|=L822$b=<%vg(iBfsP+zz}Oo!s9)S*TM}sRb;#-onRCN+6xrC&!yhcg12T>S@*8V0TWddF&Q>S}(gKLNxDL_m=MhrV%w;~8b8XBa59n|nAMG<4)$0~79)*iG7GL5mQ_VE3nVwqyr+e^{X{Q__*$&F8(YoL2-@7Zi~hq5Sm0R*p*(vH;ZzT~k1-^1^Q9OVT@2Ibhb;!z__>4?Q?hg+ab;)V~kFTBW}`WID)#N<87=3r1=o$CSrEtYhTdTG_BJ>&#X#q~~a(7il29hNLe4O7Q-OlYplK2Ci{`djA80EQfSHPYcxXSv{{$>FO*nhxH$$eCZIr>|L@PmlE^cDT*z$nMxoKQ@15mLu+qlinY>!gDQ-J_2*1(>)4nJqexdjlA3?k|6CEpp-q3>6vS4N6n_X=`|9!x~_?Cn{_x`47I_tFrvj1-;-Tt}K>d!ATFQvcyg7O#5T)S8k!f1nx)2!IlHGAlJCAq2nDQRcgR0RIW7MLX7c98Eyj1Fq;ZW9rKGvn&UV)ptN{nz`j^9GArf*TD6yVOjKJQ*}?2rD`EiKBs!4$&U~6dV!ayrJn+n3yW}RQE7ak?Is&J^zv?ox^S^!dlQZkW_`3wxR1^CW5FyA>Udr`O92Oan#!QFNmbGnCEzQ+f&;@AA-ufqh*?|+>aIQi+ZeOQ*h()dc0J<_tbYc)30?0(U|0cxIgo-5hsUhxyb6j&1mPFuPC~Lt>Ae%>y^s0$40Z&Xh`DI7(EBBgC5FbRp@hodxaf2)ez%q9n{#;j7S;J8tOyiLp8o5dSLS5VB*xSWp}STjz!fh~O$X}vn8pgl9T9rXb4P8KCTbfY!=w|Y{9YYiW<=|HcoerxN-&H(VH1TR&m7NAfM>*Y?@Z3kcZ~pUbgnzqTIEBNO%PK<2vb*w$tiD>yjKEaZ7}V0iBI|6=Trh@yG3izyi)UlQ_bSQaBYPd3%B<_C)BDup3HmdWrL;)o(n%*>PY^E!!bH@lwK5Ti%Ayc#4l|3IpUQQVPp%8^c|d
)+9)pdnF>f7t__>QWozK^_*7XK84O~Yq)--<*3HVs(^q1EaPR(SVhufwXoHmp;k`EssC+4(qZ5?lH9?|f_iy$OVUrQ((Hh*Pv$b$K(5Gx)Mz2%Z=a1_9dEO;67&HL0SY@HRQ+ZUC!VQsP?uPdq=N8tVg~QALznbMA(D|`i|$NMA2@&D!WmN~4`;hB;;Ux%DXBrtly7f7Q$LW-+I)Kf(Q3SOosu@R^MZUy@V43ITc|-+Z1p2%Vo0(y|*oTvr2%$?cHxX>Z8Q#eo2WUC7L@E8PVJ$;~2d*7jcd{my}k-g~YdIe=Z2k5hwVsDXSt0V&vrbfI8Ga7d@nrO`KYBdwHV1spG-B&+g&o2Hian#8)k*|dy}NC5?{9b@1`8p{_)5y|r}Wh`+p&pSFV=X98nSq#pVpO_0B!i;JlZ3k08>US7-s0NL!MLG33E?@8vHd7Ce*$7PloRY6B1Xh}sQ!0`#IcVV3@Hpag=f1}kmf+wFxlS)8NG{=;DsApw9!YTM>C)T5}ow-Og46u&-iDvhob-1YDaMXVky@|c4#`-<5eomK~A8VT#y0ZB1vKo?>jP3Xpc!|lq3)cCZz0~D%Vl0S3D72fLIiM0wgx>>o)_Kc@_8W@d#i%B@_pjyD(lh7<;bB@`KXEqDjglzw!J>re6L}3lXqYI1!Y5uqdXA7RH4q&g2r-L&zI$2Z7g8+SLx8m21lE9`47cXDUlSU|~h&6E2S-<~`rG?ViRtrWgI8x{dtXl}$>m$1|g4Z8ybGgVOn`EKd7!ggLNK|s)hxY^eL4rmLj;Fp%#5GA_2N;C)>k8`Z9d|n?jCS_ycvt{9K64-0kl@`NVr&e!V8iUZ!(`8dvr?O`*m_dIgWl+*ZIiGEjC%Oc@Yr5Ky~g#|B-3=LE7^PmIO)&Tv1hU*IH#+x_^)yl2^MC`WOx7hBwqoTWUl=owt4U6M7M5zO@WG1A3sX&7eC+Y88oyq{cs+6Bnyz^l(_@OF~$#t&H_Nj@j2j`5427;zIo9X=+?pMDZI;Cp}=n*{V+VbrLM+B~FGg?F$LhQ<~Zd=$b|21t?EsR2?j@NnQualPtz7vgcuu=V|v15PNez9JXA(Rq%Ymo%Ho8xfCivw^S*IpD-e39m^(jur#_WOjh^8T7;s??#H}(#h?j$c&$*7g_K~j(iJNc(DMWy&VTiv(h6Yu6>Wyje{$R4c{o&(*hHak+=lSxyPgZKl}w79J5Zj*Kuobq}=h`g+3~OG8{lmxd{~ph>M##0`-hvu16#5bK%j9OUTn^bn*b}DBC-viX5=A^T<~0(G{y&;Tz%rW7+X3jD+WnbFScL$0V%=G+plfmDmzkXPjLP@Pd57z9JRufj;b~3H{F3rN7N@ES{>zT28$d83UeJ-KphnRS~%loY+2f!#j9e#+o~0_kGBaIHq@KZ#$7!q>eEeMA;$kfH}%)@l3_*yJ#Ae&92$iq%?6vbUypoF<_D!>nYobZCR_H%*ip-uU+y7j8S}>C-^WQc(7{Xe&_?%c2deV0C0=P|x0wlp|wagL%^gBo>fbc*VVwDWnWlFK9pnltzZG~rphHJmcfg)gG%8}7=`^0c(XA|4Ftl5PPQ%GG-AvPk_l##CO#Ut>Um=_>xdU7N>aQRTcR|)|7D~N)CNKuFo*vnAGZ$N9h#7Ivoo32k3gLMq~qDz)JwLBL8)2olKPy-Rz`uBza8r*fkvy2f=flbWSkyN-pI2ChTVO4py)r#^Y+6_SMV&}qD&XS8f=ZsL&cHTFRz~)dd82@1L0yEbL2PqOt(&a=SOUH^8cG*nt8PH~U3WKML`(u%uruSlp0IZ6Hr8_u3JDXW4GVPDZdX#Ss!_X1gf$0aX~`r?N99q;=}nAOC)d_q!e5yx9Eh_#MdPfS`xC{*U^}jl9hoO@9;zh}%xpoeTqyaliTIjXSp?lB}c-DSkI;By7r0toO^$72KxM*|6#!w@@H*ccZNLbVnOb3mKNEKkcvXjfjJnd@BX0D!7Hk3xX$0!^t+@~)ly>yKMRUrLfs#6a4$$))mgTsx;%F{3J
lz{gGv=nZzDG}FVxH}V1L(ahC`JHu&6p@q04)FQpJ2sr6e9DC#1?^bu_SmU_v;So77&mZJg-JFWSA@jj|gyK=k<=^W(o@=J#>u6U|L=CESSh5#m~`&Ss26b*aLEHMb9h#Zt;dS(Y;SfQP!kXcGim~Texl-U&m!GMPvldk>cuMne0imMMbF(2JJeR!ftPV-q-?3;S9BaY&5Se+A1cR5E==j*)h)Q2Xpt^32Dc$lp}zesAtrPj62|3IvK&0iWONktf&JqE(G(M1A1-bUnJKxz~bbOEN)?CqLWGX(zQVrG#K-Hp~!u;1k<@{Vqo_2C>wc>2wVhm?FDKYExo6X5*O={nw><{0LiiNG-Bl&r}q7m@I9|PaL$D|iAdY1@1q1zy5*Wwx$DDy07g`cgL>DAH#IIO4NA)@U2*3%8F6*I;Svs3Rf)Wv3O$)j;|be;7bqAC=~zM!NLaZ&mSJCBUB$J_?&mbMc@VNut>D2e5Fs<2^~&2~Wf3_Yh6pWi_V)$k^v8o`JowGvE9z%TO5Y6ld?Rl%h^(-^T5W3?5DZ>gwXmDx1{iRTj#^}>d$>E$=Pi7w^!#?bPA6Rc)LUn=V{JLtk1&_0qnjUFd5+=(A7`T#y!%(1K0uDySa&V>nNQYE;3p3nJIa#0fL!5GMiguj=ayjiXQ8RIPC;-`hqLy|KT@<}C{H{M5$8AsSKaA~8m3Qb@V{d6J4L6rq#AmH#mDsP2nB!anLI#%`<90d#-Xi(rT+;2na_j=Yp?Q&0<-8sEkB0Gt^ROF>P?T7wL^*?0jKQnAjCqQ0C`ub8KRkV)$*W~)gia$R8~u#7(PI9f)KmgmD4CpP|mMsanDj3Q7^DVkzdk*7b7j|9NJP4v|${ax^-#UlKATm5-xz)$myC;qU6;Y(pn5Fd(9>~B&u@1%?jcsu^3gQ)nL1e8rpI#n`CSme8J;OLNTCfW@k>pXzME7I~5Ri9J`#RPJsiXb?f*jnzEzfvKj_9bgdoB9i7I};Z3FEfZ+sSW@*;9oW@|!3w%@UOwO8DP~*~MKiOgC~tYa#|MdCiX*Fsuz6B~PFCB3esB&Sg)vWoj02{93(%XEr852kdtC^Vls;v*j^a`_$f{dr1vU|FYFw8`yavJVD16QdFhU+-0ueG~*`W)65E>8q>3{H`41^isrtOoNVd8fB{|wIX%DP{jbZpTw8f9E4#C%PVy7J({31OmwdJGJvH##YjvIx=}?#kZEG0{9cHZ%>Ai}dY%MRzbt_TeZ*A#>IzFNJ1er0F~L6^3gr76}%hY)jt#SsK5H$hqDmG#^qZEpBv2F7&aA@@ytmfi^PG(oLloP+Vumu0Io>t!CuI`xLi@j%=LuNr4yK?+pdC%}=Q+QWI4Tt5CTq<#lSe6k2UQqOA6wTueWZA`9sdr>$f|}0(z-$kvp7ygF8=q8TnCz%Y4Qp>jvs$}tk8XY>zKPYk=9Mrj(jP9)avi$AO8^rItRpgI^Uce(#Z$HrCqaAV4$o6dvbAwv+!O7hhcp1L|$|6HROH*5b-c|Nfi6y*^~c%7d^s;`sHDf28|SeyU$Ho;F(;oN;7%H=+?Q+$gnHo6L=;w_W9c?rpXd{_6H`Lp0C*odIahO4I6ot5-0TeG9lg6meR0P|gjDVFkn0N1`0@VCbh1;f?sU=jeV$JgG8*UNkgEWj(=Ou(H|2CU6yjlB#40n>3Hc9u-un9TcldU$B{hdvnN&-$O{*G+hed*(q1pz^LqDdc=KeKqB=I<}0L>i3yn%DIPo~k=Q%aH=~FbD#WrCo-xjsr<3z>C_$8EYse))qzz!V^UEFSH_>NaL6RE;+U$ph+sOcGJoz=l|H3KduiL-AgAPv;(}~xz_bJ(&6rsU8T=4~nf&=PUP@wSt^KhMrSW7zsQJ3^5JgvLWC#$fWde*SRzsmr>Sj@q9zmF4=2X6xzKqS0`6L&Cuoj97JD@Bfei%jwR(}R_nGU|GfLnFfuNCCv5gw!)EsqT6BO;juK`f+lNXbM^#rIU-&acO9o0>FC{s5f`hqM#3VmOV1z##Ihk^
_{Kbo^hasmGq^7z#?fg*QY(SaU^aeEjXsB1lO|{=CbuPFr4sWE!$Olc?YuJtuv%cYCEij83}46pw%BxL>;uPhaQo9g{jGe_CcpNMK9Mmul%Lqej~dd5s0DU{1ijDLgo`Ptit-f^4!hi;Hv2t+saXuJ8rnPb`?BK~K-P$ym;awC1>f>gVGSIY_y8@6Iy#zuR_W_mP+XUQqbdgNn@l%fq;3sR>(ur-%V*)DYSza^hq(OON5C#BPnz!r360-iL~OFKk5Ztp1zcU&7(Md)c~4KMEjPVbr(j6YB5hJO?QTD9meeYUO$!P#~dSkS~ORu5fi1Ek*pU;q"),format=L.FORMAT_RAW,filters=[{"id":L.FILTER_LZMA2}])) diff --git a/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/train_gpt_human.py b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/train_gpt_human.py new file mode 100644 index 0000000000..607db377df --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/train_gpt_human.py @@ -0,0 +1,470 @@ +import collections,copy,glob,io,lzma,math,os +from pathlib import Path +import random,re,subprocess,sys,time,uuid,numpy as np,sentencepiece as spm,torch,torch.distributed as dist,torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor,nn +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class 
Hyperparameters:data_dir=os.environ.get('DATA_DIR','./data/');seed=int(os.environ.get('SEED',1337));run_id=os.environ.get('RUN_ID',str(uuid.uuid4()));iterations=int(os.environ.get('ITERATIONS',20000));warmdown_frac=float(os.environ.get('WARMDOWN_FRAC',.72));warmup_steps=int(os.environ.get('WARMUP_STEPS',20));train_batch_tokens=int(os.environ.get('TRAIN_BATCH_TOKENS',786432));train_seq_len=int(os.environ.get('TRAIN_SEQ_LEN',2048));train_log_every=int(os.environ.get('TRAIN_LOG_EVERY',500));max_wallclock_seconds=float(os.environ.get('MAX_WALLCLOCK_SECONDS',6e2));val_batch_tokens=int(os.environ.get('VAL_BATCH_TOKENS',524288));eval_seq_len=int(os.environ.get('EVAL_SEQ_LEN',2048));val_loss_every=int(os.environ.get('VAL_LOSS_EVERY',4000));sliding_window_enabled=bool(int(os.environ.get('SLIDING_WINDOW_ENABLED','1')));vocab_size=int(os.environ.get('VOCAB_SIZE',8192));num_layers=int(os.environ.get('NUM_LAYERS',11));xsa_last_n=int(os.environ.get('XSA_LAST_N',11));model_dim=int(os.environ.get('MODEL_DIM',512));embedding_dim=int(os.environ.get('EMBEDDING_DIM',512));num_kv_heads=int(os.environ.get('NUM_KV_HEADS',4));num_heads=int(os.environ.get('NUM_HEADS',8));mlp_mult=float(os.environ.get('MLP_MULT',4.));skip_gates_enabled=bool(int(os.environ.get('SKIP_GATES_ENABLED','1')));tie_embeddings=bool(int(os.environ.get('TIE_EMBEDDINGS','1')));logit_softcap=float(os.environ.get('LOGIT_SOFTCAP',3e1));rope_base=float(os.environ.get('ROPE_BASE',1e4));rope_dims=int(os.environ.get('ROPE_DIMS',16));rope_train_seq_len=int(os.environ.get('ROPE_TRAIN_SEQ_LEN',2048));ln_scale=bool(int(os.environ.get('LN_SCALE','1')));qk_gain_init=float(os.environ.get('QK_GAIN_INIT',5.25));num_loops=int(os.environ.get('NUM_LOOPS',2));loop_start=int(os.environ.get('LOOP_START',3));loop_end=int(os.environ.get('LOOP_END',6));enable_looping_at=float(os.environ.get('ENABLE_LOOPING_AT',.35));parallel_residual_start=int(os.environ.get('PARALLEL_RESIDUAL_START',7));min_lr=float(os.environ.get('MIN_LR',.0));embed_lr=float(
os.environ.get('EMBED_LR',.6));head_lr=float(os.environ.get('HEAD_LR',.008));tied_embed_lr=float(os.environ.get('TIED_EMBED_LR',.03));tied_embed_init_std=float(os.environ.get('TIED_EMBED_INIT_STD',.005));matrix_lr=float(os.environ.get('MATRIX_LR',.022));scalar_lr=float(os.environ.get('SCALAR_LR',.02));muon_momentum=float(os.environ.get('MUON_MOMENTUM',.99));muon_backend_steps=int(os.environ.get('MUON_BACKEND_STEPS',5));muon_momentum_warmup_start=float(os.environ.get('MUON_MOMENTUM_WARMUP_START',.92));muon_momentum_warmup_steps=int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS',1500));muon_row_normalize=bool(int(os.environ.get('MUON_ROW_NORMALIZE','1')));beta1=float(os.environ.get('BETA1',.9));beta2=float(os.environ.get('BETA2',.95));adam_eps=float(os.environ.get('ADAM_EPS',1e-08));grad_clip_norm=float(os.environ.get('GRAD_CLIP_NORM',.3));eval_stride=int(os.environ.get('EVAL_STRIDE',64));muon_beta2=float(os.environ.get('MUON_BETA2',.95));adam_wd=float(os.environ.get('ADAM_WD',.02));muon_wd=float(os.environ.get('MUON_WD',.095));embed_wd=float(os.environ.get('EMBED_WD',.085));ema_decay=float(os.environ.get('EMA_DECAY',.9965));ttt_enabled=bool(int(os.environ.get('TTT_ENABLED','0')));ttt_lr=float(os.environ.get('TTT_LR',.005));ttt_epochs=int(os.environ.get('TTT_EPOCHS',3));ttt_momentum=float(os.environ.get('TTT_MOMENTUM',.9));ttt_chunk_tokens=int(os.environ.get('TTT_CHUNK_TOKENS',32768));etlb_enabled=bool(int(os.environ.get('ETLB_ENABLED','0')));etlb_lr=float(os.environ.get('ETLB_LR',.05));etlb_steps=int(os.environ.get('ETLB_STEPS',5));etlb_clip=float(os.environ.get('ETLB_CLIP',3.));compressor=os.environ.get('COMPRESSOR','brotli');gptq_calibration_batches=int(os.environ.get('GPTQ_CALIBRATION_BATCHES',64));gptq_reserve_seconds=float(os.environ.get('GPTQ_RESERVE_SECONDS',12.));matrix_bits=int(os.environ.get('MATRIX_BITS',6));embed_bits=int(os.environ.get('EMBED_BITS',8));matrix_clip_sigmas=float(os.environ.get('MATRIX_CLIP_SIGMAS',12.85));embed_clip_sigmas=float(os.environ.ge
t('EMBED_CLIP_SIGMAS',2e1));distributed='RANK'in os.environ and'WORLD_SIZE'in os.environ;rank=int(os.environ.get('RANK','0'));world_size=int(os.environ.get('WORLD_SIZE','1'));local_rank=int(os.environ.get('LOCAL_RANK','0'));is_main_process=rank==0;grad_accum_steps=8//world_size;datasets_dir=os.path.join(data_dir,'datasets',f"fineweb10B_sp{vocab_size}");train_files=os.path.join(datasets_dir,'fineweb_train_*.bin');val_files=os.path.join(datasets_dir,'fineweb_val_*.bin');tokenizer_path=os.path.join(data_dir,'tokenizers',f"fineweb_{vocab_size}_bpe.model");logfile=f"logs/{run_id}.txt";model_path='final_model.pt';quantized_model_path='final_model.int6.ptz' +_logger_hparams=None +def set_logging_hparams(h):global _logger_hparams;_logger_hparams=h +def log(msg,console=True): + if _logger_hparams is None:print(msg);return + if _logger_hparams.is_main_process: + if console:print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile,'a',encoding='utf-8')as f:print(msg,file=f) +class ValidationData: + def __init__(self,h,device): + self.sp=spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size())!=h.vocab_size:raise ValueError(f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}") + self.val_tokens=load_validation_tokens(h.val_files,h.eval_seq_len);self.base_bytes_lut,self.has_leading_space_lut,self.is_boundary_token_lut=build_sentencepiece_luts(self.sp,h.vocab_size,device) +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());assert sp.piece_to_id('▁')!=sp.unk_id(),"Tokenizer must have '▁' (space) as its own token for correct BPB byte counting";table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or 
sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=False + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if piece.startswith('▁'):has_leading_space_np[token_id]=True;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode('utf-8')) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[:usable+1] +def load_data_shard(file): + header_bytes=256*np.dtype('0 else 0;num_sequences=(self.num_tokens[si]-1-phase)//self.seq_len;sequence_order=self.rng.permutation(num_sequences);self.start_inds[si]=(phase+sequence_order*self.seq_len).tolist() + def next_batch(self,global_tokens,grad_accum_steps): + device_tokens=global_tokens//(self.world_size*grad_accum_steps);device_batch_size=device_tokens//self.seq_len;remaining=np.array([len(s)for s in self.start_inds],dtype=np.float64);x=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64);y=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64) + for bi in range(device_batch_size): + total=remaining.sum() + if total<=0: + for si in range(len(self.files)):self._reset_shard(si) + remaining=np.array([len(s)for s in self.start_inds],dtype=np.float64);total=remaining.sum() + 
probs=remaining/total;si=int(self.rng.choice(len(self.files),p=probs));start_ind=self.start_inds[si].pop();remaining[si]-=1;mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[start_ind:start_ind+self.seq_len+1],dtype=np.int64));x[bi]=window[:-1];y[bi]=window[1:] + return x.to(self.device,non_blocking=True),y.to(self.device,non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self,eps=None):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + def forward(self,x):w=self.weight.to(x.dtype);bias=self.bias.to(x.dtype)if self.bias is not None else None;return F.linear(x,w,bias) +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=1./base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=False);self._seq_len_cached=0;self._cos_cached=None;self._sin_cached=None + def forward(self,seq_len,device,dtype): + if self._cos_cached is None or self._sin_cached is None or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=1./new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[None,:,None,:];self._sin_cached=freqs.sin()[None,:,None,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0: + head_dim=h.model_dim//h.num_heads + for block in 
self.blocks:block.attn.rope_dims=h.rope_dims;block.attn.rotary=Rotary(head_dim,base=h.rope_base,train_seq_len=h.train_seq_len,rope_dims=h.rope_dims) + self.final_norm=RMSNorm();self.lm_head=None if h.tie_embeddings else CastedLinear(h.embedding_dim,h.vocab_size,bias=False) + if self.lm_head is not None:self.lm_head._zero_init=True + if h.xsa_last_n>0: + for i in range(max(0,h.num_layers-h.xsa_last_n),h.num_layers):self.blocks[i].attn.use_xsa=True + if h.parallel_residual_start>=0: + for i in range(h.parallel_residual_start,h.num_layers):self.blocks[i].parallel=True + self.looping_active=False + if h.num_loops>0: + loop_seg=list(range(h.loop_start,h.loop_end+1));all_indices=list(range(h.loop_start)) + for _ in range(h.num_loops+1):all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end+1,h.num_layers));num_enc=len(all_indices)//2;self.encoder_indices=all_indices[:num_enc];self.decoder_indices=all_indices[num_enc:] + else:self.encoder_indices=list(range(self.num_encoder_layers));self.decoder_indices=list(range(self.num_encoder_layers,h.num_layers)) + self.num_skip_weights=min(len(self.encoder_indices),len(self.decoder_indices));self.skip_weights=nn.Parameter(torch.ones(self.num_skip_weights,h.model_dim,dtype=torch.float32));self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,h.model_dim,dtype=torch.float32))if h.skip_gates_enabled else None;self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=.0,std=self.tied_embed_init_std) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if getattr(module,'_zero_init',False):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=1.) 
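The `loop_seg`/`all_indices` construction above is what produces the virtual layer sequences shown in the README. The following standalone sketch mirrors that logic (the helper name `build_virtual_layers` is illustrative, not from this file), using the README's values: `num_layers=11`, `loop_start=3`, `num_loops=2`, and `loop_end=5` (SOTA) vs. `loop_end=6` (this PR):

```python
def build_virtual_layers(num_layers, loop_start, loop_end, num_loops):
    # Pre-loop layers run once, the loop segment [loop_start..loop_end]
    # runs num_loops+1 times, post-loop layers run once. The first half
    # of the flattened schedule becomes the encoder (feeding U-Net skips),
    # the second half the decoder.
    loop_seg = list(range(loop_start, loop_end + 1))
    all_indices = list(range(loop_start))
    for _ in range(num_loops + 1):
        all_indices.extend(loop_seg)
    all_indices.extend(range(loop_end + 1, num_layers))
    num_enc = len(all_indices) // 2
    return all_indices[:num_enc], all_indices[num_enc:]

# SOTA (loop_end=5): 17 virtual layers
enc5, dec5 = build_virtual_layers(11, 3, 5, 2)
# This PR (loop_end=6): 19 virtual layers, layer 6 promoted into the loop
enc6, dec6 = build_virtual_layers(11, 3, 6, 2)
```

Note that with 19 virtual layers the split is 9 encoder / 10 decoder steps, which is where the 9th U-Net skip connection comes from.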
+ def forward_logits(self,input_ids): + x=self.tok_emb(input_ids);x=F.rms_norm(x,(x.size(-1),)) + if self.embed_proj is not None:x=self.embed_proj(x) + x0=x;skips=[];enc_iter=self.encoder_indices if self.looping_active else range(self.num_encoder_layers);dec_iter=self.decoder_indices if self.looping_active else range(self.num_encoder_layers,self.num_encoder_layers+self.num_decoder_layers) + for i in enc_iter:x=self.blocks[i](x,x0);skips.append(x) + for(skip_idx,i)in enumerate(dec_iter): + if skip_idxG.size(1) + if transposed:X=X.T + for _ in range(steps):A=X@X.T;B=b*A+c*A@A;X=a*X+B@X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=True,weight_decay=.0,row_normalize=False):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay,row_normalize=row_normalize)) + @torch.no_grad() + def step(self,closure=None): + loss=None + if closure is not None: + with torch.enable_grad():loss=closure() + distributed=dist.is_available()and dist.is_initialized();world_size=dist.get_world_size()if distributed else 1;rank=dist.get_rank()if distributed else 0 + for group in self.param_groups: + params=group['params'] + if not params:continue + lr=group['lr'];momentum=group['momentum'];backend_steps=group['backend_steps'];nesterov=group['nesterov'];total_params=sum(int(p.numel())for p in params);updates_flat=torch.zeros(total_params,device=params[0].device,dtype=torch.bfloat16);curr=0 + for(i,p)in enumerate(params): + if i%world_size==rank and p.grad is not None: + g=p.grad;state=self.state[p] + if'momentum_buffer'not in state:state['momentum_buffer']=torch.zeros_like(g) + buf=state['momentum_buffer'];buf.mul_(momentum).add_(g) + if nesterov:g=g.add(buf,alpha=momentum) + if group.get('row_normalize',False):row_norms=g.float().norm(dim=-1,keepdim=True).clamp_min(1e-07);g=g/row_norms.to(g.dtype) + 
g=zeropower_via_newtonschulz5(g,steps=backend_steps);g*=max(1,g.size(0)/g.size(1))**.5;updates_flat[curr:curr+p.numel()]=g.reshape(-1) + curr+=p.numel() + if distributed:dist.all_reduce(updates_flat,op=dist.ReduceOp.SUM) + wd=group.get('weight_decay',.0);curr=0 + for p in params: + if wd>.0:p.data.mul_(1.-lr*wd) + g=updates_flat[curr:curr+p.numel()].view_as(p).to(dtype=p.dtype);p.add_(g,alpha=-lr);curr+=p.numel() + return loss +CONTROL_TENSOR_NAME_PATTERNS=tuple(pattern for pattern in os.environ.get('CONTROL_TENSOR_NAME_PATTERNS','attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates').split(',')if pattern) +class Optimizers: + def __init__(self,h,base_model): + block_named_params=list(base_model.blocks.named_parameters());matrix_params=[p for(name,p)in block_named_params if p.ndim==2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)];scalar_params=[p for(name,p)in block_named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel()>0:scalar_params.append(base_model.skip_gates) + token_lr=h.tied_embed_lr if h.tie_embeddings else h.embed_lr;tok_params=[{'params':[base_model.tok_emb.weight],'lr':token_lr,'base_lr':token_lr}];self.optimizer_tok=torch.optim.AdamW(tok_params,betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.embed_wd,fused=True);self.optimizer_muon=Muon(matrix_params,lr=h.matrix_lr,momentum=h.muon_momentum,backend_steps=h.muon_backend_steps,weight_decay=h.muon_wd,row_normalize=h.muon_row_normalize) + for group in self.optimizer_muon.param_groups:group['base_lr']=h.matrix_lr + 
self.optimizer_scalar=torch.optim.AdamW([{'params':scalar_params,'lr':h.scalar_lr,'base_lr':h.scalar_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.adam_wd,fused=True);self.optimizers=[self.optimizer_tok,self.optimizer_muon,self.optimizer_scalar] + if base_model.lm_head is not None:self.optimizer_head=torch.optim.Adam([{'params':[base_model.lm_head.weight],'lr':h.head_lr,'base_lr':h.head_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,fused=True);self.optimizers.insert(1,self.optimizer_head) + else:self.optimizer_head=None + def __iter__(self):return iter(self.optimizers) + def zero_grad_all(self): + for opt in self.optimizers:opt.zero_grad(set_to_none=True) + def step(self): + for opt in self.optimizers:opt.step() + self.zero_grad_all() +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module,CastedLinear):module.float() + for(name,param)in model.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +def collect_hessians(model,train_loader,h,device,n_calibration_batches=64): + hessians={};hooks=[] + def make_hook(name): + def hook_fn(module,inp,out): + x=inp[0].detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + for(name,module)in model.named_modules(): + if isinstance(module,CastedLinear)and module.weight.numel()>65536: + cat=classify_param(name+'.weight') + if cat in('mlp','attn'):hooks.append(module.register_forward_hook(make_hook(name+'.weight'))) + if model.tie_embeddings: + hook_module=model.head_proj if model.head_proj is not None else model.final_norm + def make_output_hook(name): + def hook_fn(module,inp,out): + x=out.detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in 
hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook('tok_emb.weight'))) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x) + for hook in hooks:hook.remove() + for name in hessians:hessians[name]=hessians[name].cpu()/n_calibration_batches + return hessians +def gptq_quantize_weight(w,H,clip_sigmas=3.,clip_range=63,block_size=128): + W_orig=w.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=.01*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=True);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm];Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=True);row_std=W_orig.std(dim=1);s=(clip_sigmas*row_std/clip_range).clamp_min(1e-10).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +_BSHF_MAGIC=b'BSHF' +def _byte_shuffle(data,stride=2): + if stride<=1 or len(data)0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=h.ttt_lr*.5*(1.+math.cos(math.pi*ci/max(num_chunks-1,1))) + 
for pg in optimizer.param_groups:pg['lr']=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(h.ttt_epochs): + for bs in range(0,my_chunk_seqs,batch_seqs): + be=min(bs+batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_data.val_tokens.numel():continue + local=val_data.val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type='cuda',dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not None:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,1.);optimizer.step() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + for p in base_model.parameters():p.requires_grad_(True) + base_model.eval();return _loss_bpb(loss_sum,token_count,byte_count) +def timed_eval(label,fn,*args,**kwargs):torch.cuda.synchronize();t0=time.perf_counter();val_loss,val_bpb=fn(*args,**kwargs);torch.cuda.synchronize();elapsed_ms=1e3*(time.perf_counter()-t0);log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms");return val_loss,val_bpb +def train_model(h,device,val_data): + base_model=GPT(h).to(device).bfloat16();restore_fp32_params(base_model);compiled_model=torch.compile(base_model,dynamic=False,fullgraph=True) + if h.distributed:model=DDP(compiled_model,device_ids=[h.local_rank],broadcast_buffers=False) + else:model=compiled_model + log(f"model_params:{sum(p.numel()for p in 
base_model.parameters())}");optimizers=Optimizers(h,base_model);train_loader=ShuffledSequenceLoader(h,device);max_wallclock_ms=1e3*h.max_wallclock_seconds if h.max_wallclock_seconds>0 else None + if max_wallclock_ms is not None:max_wallclock_ms-=h.gptq_reserve_seconds*1e3;log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + def training_frac(step,elapsed_ms): + if max_wallclock_ms is None:return step/max(h.iterations,1) + return elapsed_ms/max(max_wallclock_ms,1e-09) + def lr_mul(frac): + if h.warmdown_frac<=0:return 1. + if frac>=1.-h.warmdown_frac:return max((1.-frac)/h.warmdown_frac,h.min_lr) + return 1. + def step_fn(step,lr_scale): + optimizers.zero_grad_all();train_loss=torch.zeros((),device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed:model.require_backward_grad_sync=micro_step==h.grad_accum_steps-1 + x,y=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps) + with torch.autocast(device_type='cuda',dtype=torch.bfloat16,enabled=True):loss=model(x,y) + train_loss+=loss.detach();(loss/h.grad_accum_steps).backward() + train_loss/=h.grad_accum_steps;frac=min(step/h.muon_momentum_warmup_steps,1.)if h.muon_momentum_warmup_steps>0 else 1.;muon_momentum=(1-frac)*h.muon_momentum_warmup_start+frac*h.muon_momentum + for group in optimizers.optimizer_muon.param_groups:group['momentum']=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group['lr']=group['base_lr']*lr_scale + if h.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),h.grad_clip_norm) + optimizers.step();return train_loss + if h.warmup_steps>0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,1.) 
+ if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops>0: + base_model.looping_active=True;log(f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,1.) + if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active=False + base_model.load_state_dict(initial_model_state,strict=True) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=True):opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed:model.require_backward_grad_sync=True + train_loader=ShuffledSequenceLoader(h,device) + ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=h.ema_decay;training_time_ms=.0;stop_after_step=None;torch.cuda.synchronize();t0=time.perf_counter();step=0 + while True: + last_step=step==h.iterations or stop_after_step is not None and step>=stop_after_step;should_validate=last_step or h.val_loss_every>0 and step%h.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(h,device,val_data,model);log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not None and step0 and not base_model.looping_active and frac>=h.enable_looping_at:base_model.looping_active=True;log(f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + train_loss=step_fn(step,scale) + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=1.-ema_decay) + 
step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0);should_log_train=h.train_log_every>0 and(step<=5 or step%h.train_log_every==0 or stop_after_step is not None) + if should_log_train:tok_per_sec=step*h.train_batch_tokens/(approx_training_time_ms/1e3);log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}") + reached_cap=max_wallclock_ms is not None and approx_training_time_ms>=max_wallclock_ms + if h.distributed and max_wallclock_ms is not None:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap:stop_after_step=step + log(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=True);return base_model,compiled_model +def train_and_eval(h,device): + random.seed(h.seed);np.random.seed(h.seed);torch.manual_seed(h.seed);torch.cuda.manual_seed_all(h.seed);val_data=ValidationData(h,device);log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob("fineweb_train_*.bin")))}");log(f"val_tokens: {val_data.val_tokens.numel()-1}");base_model,compiled_model=train_model(h,device,val_data);torch._dynamo.reset();timed_eval('pre-quantization post-ema',eval_val,h,device,val_data,compiled_model);serialize(h,base_model,Path(__file__).read_text(encoding='utf-8')) + if h.distributed:dist.barrier() + eval_model=deserialize(h,device) + if h.num_loops>0:eval_model.looping_active=True + compiled_model=torch.compile(eval_model,dynamic=False,fullgraph=True);timed_eval('quantized',eval_val,h,device,val_data,compiled_model) + if 
h.sliding_window_enabled:timed_eval('quantized_sliding_window',eval_val_sliding,h,device,val_data,eval_model) + if h.ttt_enabled and h.sliding_window_enabled: + del eval_model,compiled_model;torch._dynamo.reset();torch.cuda.empty_cache();ttt_model=deserialize(h,device) + if h.num_loops>0:ttt_model.looping_active=True + timed_eval('quantized_ttt',eval_val_ttt,h,device,val_data,ttt_model);del ttt_model + if h.etlb_enabled and h.sliding_window_enabled: + if'eval_model'not in dir(): + eval_model=deserialize(h,device) + if h.num_loops>0:eval_model.looping_active=True + timed_eval('quantized_sliding_etlb',eval_val_sliding_etlb,h,device,val_data,eval_model) +def main(): + world_size=int(os.environ.get('WORLD_SIZE','1'));local_rank=int(os.environ.get('LOCAL_RANK','0'));distributed='RANK'in os.environ and'WORLD_SIZE'in os.environ + if not torch.cuda.is_available():raise RuntimeError('CUDA is required') + if world_size<=0:raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8%world_size!=0:raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + device=torch.device('cuda',local_rank);torch.cuda.set_device(device) + if distributed:dist.init_process_group(backend='nccl',device_id=device);dist.barrier() + torch.backends.cuda.matmul.allow_tf32=True;torch.backends.cudnn.allow_tf32=True;torch.set_float32_matmul_precision('high');from torch.backends.cuda import enable_cudnn_sdp,enable_flash_sdp,enable_math_sdp,enable_mem_efficient_sdp;enable_cudnn_sdp(False);enable_flash_sdp(True);enable_mem_efficient_sdp(False);enable_math_sdp(False);torch._dynamo.config.optimize_ddp=False;h=Hyperparameters();set_logging_hparams(h) + if h.is_main_process: + os.makedirs('logs',exist_ok=True);log(100*'=',console=False);log('Hyperparameters:',console=True) + for(k,v)in sorted(vars(type(h)).items()): + if not k.startswith('_'):log(f" {k}: {v}",console=True) + log('='*100,console=False);log(f"Running Python 
{sys.version}",console=False);log(f"Running PyTorch {torch.__version__}",console=False);log(subprocess.run(['nvidia-smi'],stdout=subprocess.PIPE,stderr=subprocess.PIPE,text=True,check=False).stdout,console=False);log('='*100,console=False) + train_and_eval(h,device) + if distributed:dist.destroy_process_group() +if __name__=='__main__':main() \ No newline at end of file
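For reference, the orthogonalization step inside `Muon` (whose `zeropower_via_newtonschulz5` definition is partially garbled above) follows the quintic Newton-Schulz iteration `A = X @ X.T; B = b*A + c*A@A; X = a*X + B@X`. The sketch below is a NumPy reimplementation for illustration; the coefficient defaults are the widely published Muon values and are an assumption here, since the original function signature did not survive extraction:

```python
import numpy as np

def orthogonalize_ns5(G, steps=5, a=3.4445, b=-4.7750, c=2.0315):
    # Quintic Newton-Schulz iteration: drives the singular values of G
    # toward 1, approximating the nearest semi-orthogonal matrix without
    # an explicit SVD. (a, b, c are assumed Muon defaults.)
    X = G.astype(np.float64)
    X = X / (np.linalg.norm(X) + 1e-7)  # Frobenius-normalize so spectral norm <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:                       # iterate with the smaller Gram matrix
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X
```

In the optimizer above this is applied to each 2D parameter's momentum buffer, then scaled by `max(1, rows/cols)**0.5` before the learning-rate update.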