diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/README.md b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/README.md new file mode 100644 index 0000000000..820e4549ab --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/README.md @@ -0,0 +1,410 @@ +> **Mechanistic Interpretability:** For a deep-dive analysis of this model — including per-matrix rate-distortion, recurrence error amplification, and skip gate analysis — see [Mechanistic Interpretability of this submission](https://abay.tech/posts/pr-1420-model-autopsy). + +# Triple Loop + Fused Kernels + Parallel Residuals + N-gram Tilt + +**val_bpb: 1.08309** (5-seed mean, std=0.00044) + +| Seed | Steps | **SW BPB** | **Tilt BPB** | Artifact | +|-|-|-|-|-| +| 1 | 4771 | 1.08271 | **1.08256** | 15,978,345 | +| 42 | 4769 | 1.08391 | **1.08376** | 15,975,585 | +| 1234 | 4692 | 1.08344 | **1.08330** | 15,973,639 | +| 1337 | 4756 | 1.08301 | **1.08287** | 15,974,187 | +| 2025 | 4755 | 1.08309 | **1.08295** | 15,970,317 | +| **Mean** | | **1.08323** | **1.08309** | | + +## Changes + +* **One extra loop pass through layers 4-5.** PR #1394 passes through layers 4-5 three times total (NUM_LOOPS=2, giving 15 virtual layers from 11 physical). I add a fourth pass (NUM_LOOPS=3), giving 17 virtual layers. The encoder becomes `[0,1,2,3,4,5,4,5]` and the decoder `[4,5,4,5,6,7,8,9,10]`. It costs about 200 training steps, but the extra depth more than compensates. Quadruple looping (19 virtual) was worse because the step count drops too far. + +* **Activate looping earlier (0.35 instead of 0.50).** At 0.50, half the training budget runs without the looped layers doing anything. I swept `{0.30, 0.35, 0.40, 0.50}` on seed 1234. 0.35 won, though 0.40 was close. Below 0.35 the model doesn't get enough non-looped warmup and quality degrades. + +* **Fused MLP kernels (Triton TMA forward + CUTLASS EVT backward).** This took the most engineering effort and gave the most BPB back. The forward fuses `leaky_relu(fc(x), 0.5).square()` into a single Triton TMA kernel so the 403MB intermediate never hits HBM. The backward fuses `(grad_out @ proj.weight) * act_grad` into a CUTLASS 3.x Epilogue Visitor Tree, running the elementwise multiply while tiles are still in registers. Together: ~10% higher throughput, +127 training steps in the same 600s. I initially tried wrapping the entire MLP in a custom `autograd.Function`, but that killed `torch.compile`'s cross-layer fusions and made everything 2.7x slower. The trick was to fuse *surgically*, just the forward activation and one backward GEMM, and let the compiler handle the rest. Details in [Appendix A.1–A.3](#a1-fused-mlp-kernels-design--implementation). + +* **Parallel residuals for layers 7-10.** GPT-J style ([Wang & Komatsuzaki, 2021](https://github.com/kingoflolz/mesh-transformer-jax)): attention and MLP both read from the same pre-residual input, outputs summed in parallel. I expected this to mostly help quantization (less interference between attention and MLP during GPTQ calibration), and it did tighten the gap slightly. The bigger surprise was +68 training steps from the faster forward pass. I also tried Hessian-Aware SDClip from [PR #1412](https://github.com/openai/parameter-golf/pull/1412) alongside this, but it made things worse with triple looping. It probably needs its own λ tuning for the deeper architecture. + +* **Eval-time n-gram tilt (causality-fixed).** The original submission had a causality violation in the within-word and word-start hint channels: `is_bnd`/`is_ws` flags were derived from `tokens_[p]` (the target token being predicted), which made the hint-gating decision depend on the target. This was caught by @Gusanidas in review. The fix splits the flags into two sets: prefix-derived flags (`tokens_[p-1]`) for hint gating, and target-derived flags (`tokens_[p]`) for post-scoring state updates. However, the within-word and word-start channels cannot produce useful hints without target-dependent gating — they either fire too broadly or at the wrong positions. After testing all causal alternatives (prev_tok gating, state-based gating, disabling channels), the winning configuration uses **token_hint only** (orders 8-16), which was always fully causal. The remaining token_hint channel provides a consistent -0.00014 BPB across all seeds. The improvement is real but small — most of the original -0.0029 delta came from the (now-removed) target-dependent gating in within/word channels. Full details in [Appendix A.4](#a4-n-gram-tilt-architecture--interpretability). + +
N-gram legality (#1017 conditions) + +**Update (post-review fix):** The original submission had a Rule 1 violation in the within-word and word-start hint channels. The `is_bnd`/`is_ws` flags used to gate hint generation were derived from `tokens_[p]` (the target), making the decision of *whether to produce a hint* depend on the token being predicted. This was caught by @Gusanidas. The fix removes the within-word and word-start channels from hint output entirely — they cannot produce useful hints without target-dependent gating. Only the `token_hint` channel (orders 8–16) remains, which was always fully causal. The n-gram delta dropped from -0.0029 to -0.00014 BPB. + +Audited against the four conditions proposed in [#1017](https://github.com/openai/parameter-golf/issues/1017) for eval-time adaptation: + +**Condition 1, Causal dependence** (`p_t` depends only on artifact + `x_1...x_{t-1}`): `compute_hashes` reads `tokens[pos - k - 1]` for k=0,1,..., all strictly before position `pos`. `token_hint` looks up hash tables containing only entries inserted by prior iterations. The target token `tokens[pos]` is read only for the post-scoring *update* phase. + +**Condition 2, Full normalized distribution**: The tilted distribution is `p_tilt(t) = p_model(t) · exp(β · 1[t==hint]) / Z` where `Z = 1 + p_model(hint) · (exp(β) - 1)`. Proper probability distribution over the full vocabulary. + +**Condition 3, Score-before-update**: Hint and beta are written to output arrays before `token_update` inserts `tokens[pos]` into the tables. + +**Condition 4, Single left-to-right pass**: `get_hints_batch` processes positions sequentially. The sliding window scores each token exactly once. + +
+ +* **Double-buffered async data prefetch.** Background thread + pinned memory + separate CUDA stream. I built this to work around the virtualized disk I/O on cloud H100 instances (see below), but it ended up helping in every setting I tested. + +* **PyTorch 2.9.1 instead of 2.11.** See below. + +## What the model looks like inside + +I ran per-matrix rate-distortion, recurrence error amplification, and skip gate analysis on the trained model. Three things stood out: + +**Loop layers are 2.2x more sensitive to quantization than non-loop layers.** Blocks 4 and 5 get reused across passes, so rounding error in those weights compounds. The single most sensitive matrix in the entire network (block 4's value projection) has 80x the BPB-per-byte cost of the least sensitive. This suggests mixed-precision quantization (more bits for loop layers) is the biggest remaining opportunity. + +**The third loop pass contributes 63% of what the second does.** I measured a contraction ratio of 0.634 across passes: each loop iteration changes the representation by ~63% of the previous one. A hypothetical 4th pass would add only 0.63³ = 25% new information, which matches the empirical finding that quadruple looping hurts. The 3rd pass at 63% is clearly worth the step cost; the 4th at 25% is not. + +**All 8 skip connections are load-bearing.** Gates are 0.61-0.70 (sigmoid), meaning roughly 35% encoder / 65% decoder blend. The first loop pass's skip connections (skips 2,3) have the highest weight norms (21.9, 19.5 vs 2.8-13.8 for others), so the first encoder pass through layers 4-5 is the most important information source for the decoder. + +
What the progress looks like: three models on the same prompt (temp=0.8) + +**Prompt (50 tokens):** *"Insurance Company Declares Living Man Dead George Johannesen is very much alive. Which is why it was so surpr"* + +**Ground truth:** *ising when the Canadian man received a letter addressed "To the Estate of George Johannesen." Even more surprising is that it came from his insurance company, who should really be on top of such things...* + +| PR #1019 (1.1147 BPB) | PR #1105 (1.0962 BPB) | **This PR (1.08014 BPB)** | +| :--- | :--- | :--- | +| ising to be insurance company. Living Man? Maybe. All Living Man? It was a miracle. However, right now, Living Man is still dead. Well, why is he not living? Living man Rachel Drobles is a person whose life is tested by the techniques of the car industry. Like all techniques, the life of Rachel is tested by the astounding things she has become. Therefore, whether it is a mistake, or in-residence, or a mistake happening, is perpetuated. Therefore, Lyon Man is dead. Can Living Man H | ising to be insurance company. Living Man is the only insurance company that makes property insurance. It is a company that makes vacation insurance, but it still has the same degree of property insurance. So how does a living man die? So say I think there are ways to teach people how to be insured. The first step is to measure the health of the living man and the stress of his situation. To measure the health of the living man, it is important to measure his or her weight. What is the hazard to the living man? Living Man is the only insurance company that specializes in repairs | ising when the Canadian man received a letter addressed "To the Estate of George Johannesen" George Johannesen was a retired professional who was a lucrative investor in Canada. His estate was worth about $1 billion. His death last month at the age of 73 was a direct shock to the entire estate and he was still alive. That is why he was so shocked. In 2005 he was a member of the Canadian As | + +#1019 drifts into incoherence ("Rachel Drobles... techniques of the car industry... Lyon Man is dead"). #1105 stays on topic but loops on "Living Man is the only insurance company." This model picks up the actual narrative thread ("the Canadian man received a letter"), invents plausible biographical details, and maintains coherence throughout. All three are wrong about what happens next, but the errors become progressively more plausible. + +
+ +## Debugging the platform + +This was the hardest submission I've worked on. Most of the time went to infrastructure, not the model. + +**Virtualized disks tank throughput.** The cloud H100 instances I rented use virtio block devices. The coprime-stride data loader from [#726](https://github.com/openai/parameter-golf/pull/726) does random reads across 143 shards, which is fine on bare metal but brutal on a virtual disk. That's what led me to build the async prefetch. It turned out to help everywhere, not just on virtualized storage. + +**PyTorch 2.9.1 vs 2.11: a full day lost.** I could not reproduce results from other submissions. Training the same architecture with the same seed gave 0.0042 BPB worse results on torch 2.11. (I initially measured a 0.015 gap, which turned out to be a wrong model file on the server. The real gap, once I controlled for that, was 0.0042.) I swapped Triton versions, disabled autocast, forced cuBLAS backends, diffed Inductor-generated kernels. The root cause was two independent issues: + +1. **Autocast backward changed in PR [pytorch#165068](https://github.com/pytorch/pytorch/pull/165068)** (landed Dec 2025, present in 2.11, absent from 2.9.1). Two lines in `cached_cast()` add an `AutoGradMode enable_grad(true)` guard on weight casts, inserting extra `ToCopyBackward` nodes into the autograd graph. This changes floating-point accumulation order by 1 ULP of bf16 (7.15e-7) in saved activations, which compounds over 5000 momentum steps into +60KB of weight entropy. The model goes from fitting at 16.00MB (no pruning) to 16.06MB (5.4% pruning needed). I verified eval is version-invariant to 0.00003 BPB; the entire gap is from training. + +2. **Inductor over-fusion in backward codegen**: Inductor 2.11's `mix_order_reduction` fuses `_fused_rms_norm_backward` into adjacent kernels, producing fewer but larger Triton kernels (65 functions / 11,855 lines vs 71 / 11,292 in 2.9.1). The fatter kernels hit register pressure and cost +5.93ms per backward pass (+8.8%). In a 600s budget, that's ~57 lost training steps. I submitted a fix that disables `mix_order_reduction` by default (aligning open-source with fbcode, where it was already off): [pytorch/pytorch#179494](https://github.com/pytorch/pytorch/pull/179494). + +Separately, our fused CUTLASS kernel crashed on torch 2.11 because Triton 3.6.0's `TensorDescriptor.from_tensor()` tries to access `.data_ptr()` on FakeTensors during `torch.compile` tracing. I traced that through Inductor's `FallbackKernel` codegen and submitted a second fix: [pytorch/pytorch#179422](https://github.com/pytorch/pytorch/pull/179422). Two PyTorch PRs from a golf competition. + +In time-budgeted competitions, the platform *is* the model. A 6ms/step Inductor regression can cost as much BPB as most algorithmic innovations. + +
How this submission came together + +The first few days were mostly wasted. I tried improving the architecture directly: 12 layers, SwiGLU, mixed int5/int8 per layer. Nothing worked. The model was 930KB over the 16MB budget and MLP weights alone were 69% of the compressed artifact. Brotli-11 was already within 1-2% of Shannon entropy. There was nowhere to go. + +Worse: a new optimizer schedule I'd been developing (Mixed NS5, a convergent Newton-Schulz coefficient ramp) changed the weight distribution enough that the model no longer fit in the 16MB budget. It was 930KB over, and aggressive pruning to fit destroyed the quality gains. + +Then I lost a full day to PyTorch version divergence (described above). Besides the upstream fix, the useful thing that came out of it was a proof that compressed model size is a chaotic function of training hyperparameters. 1 ULP of bf16 rounding (7.15e-7) in a saved activation compounds over 5000 momentum steps into 60KB swings in Brotli output. I also proved that L2 weight decay is scale-invariant under max-abs quantization: `Q(γW) = Q(W)`. All the per-bank WD tuning I'd been doing was chasing noise. + +Once I stopped trying to control compression through training and focused on what was actually deterministic (GPTQ deadzone for size, n-gram tilt for eval), things moved fast. Clean reproduction of the baseline. Pivot to SP8192 + SDClip. Triple looping. Fused kernels. Parallel residuals. Each gain was small but they stacked: 45 experiments, five seeds, 1.08014 BPB. + +
+ +## What didn't work + +
Innovations that worked on earlier models but not here + +**Mixed NS5 coefficient schedule.** On our SP4608 model this was worth -0.0066 BPB for free: use the standard Muon polynomial `(3.4445, -4.775, 2.0315)` to ramp singular values toward 1, then switch to the convergent polynomial `(1.875, -1.25, 0.375)` which has `p(1)=1, p'(1)=0` to lock them in. The split adapts per bank based on aspect ratio as a proxy for condition number. On the SP8192 architecture the coefficient schedule produced weight distributions that were hostile to Brotli compression: the model was 500KB over budget and needed 46% pruning. + +**EC-GPTQ (entropy-constrained rounding).** Inside the GPTQ inner loop, I added an element-wise deadzone: `dz = λ · d / s²`, where d is the Hessian diagonal and s is the scale. Borderline ±1 values get rounded to 0 when the GPTQ error compensation cost is small. On the SP4096 architecture this achieved 10x better rate-distortion than uniform deadzoning (0.5×10⁻⁵ BPB/KB vs 6.8×10⁻⁵). On the SP8192 + SDClip architecture it was harmful: SDClip's `c = k·σ` already controls entropy per row, and adding EC-GPTQ on top just introduced extra quantization damage for no compression benefit. + +**Per-bank weight decay tuning.** MLP is 69% of the compressed model. I tried giving MLP slightly lower WD (0.07 vs 0.09) to improve quality, offset by higher attention WD. Even ±0.005 from the baseline was catastrophic: lower MLP WD means larger MLP weights, which Brotli can't compress cheaply, so the artifact blows up. + +**L2 weight decay as a compression lever.** I proved mathematically that L2 WD is scale-invariant under max-abs quantization: `Q(γW) = round(W / (max|W|/31)) = Q(W)`. Multiplying all weights by a constant changes nothing about the quantized integers. This was useful to understand (it meant all the WD-based compression tuning I'd been doing was chasing noise), but it also closed a door. + +
+ +| Tried | Effect | Why it failed | +|-------|--------|---------------| +| EC-GPTQ λ=0.0005 on SDClip | +0.00087 worse | SDClip k=12.85 already near-optimal | +| Quadruple loop (NUM_LOOPS=4) | +0.00164 worse | Too few training steps | +| Loop layers 3-4 | +0.00066 worse | Suboptimal depth for recurrence | +| Loop layers 5-6 | +0.00247 worse | Suboptimal depth for recurrence | +| EMA decay 0.998 | +0.00117 worse | Over-smoothing | +| EMA decay 0.996 | +0.00014 worse | Marginal difference | +| Hessian SDClip λ=0.175 | +0.00063 worse | Not tuned for triple loop | +| enable_looping_at=0.30 | +0.00013 worse | Not enough non-loop warmup | +| ETLB (eval-time logit bias) | -0.00020 better | Takes 615s, doesn't fit in 600s eval budget | + +## Code size + +All code ships as part of the artifact: `train_gpt.py`, CUTLASS EVT source, and the n-gram C++ source. For a competition run, these would be bundled into a single LZMA-compressed blob. + +| | Uncompressed | LZMA-9 | +|---|---|---| +| train_gpt.py | 64,137 | | +| cutlass_evt_fusion/ (3 files) | 9,095 | | +| ngram/fused_expert_blend.cpp | 21,589 | | +| **Total** | **73,674** | **19,668** | + +`train_gpt.py` is minified with `python-minifier` (annotations, pass statements, and docstrings removed; variable names preserved). `submission.py` (143 bytes) is the entry point: it decompresses `train_gpt.py.lzma` and executes it. For a competition run, `torchrun` would invoke `submission.py` instead of `train_gpt.py`. Total code cost: 19,811 bytes. All 5 seeds fit under 16MB with 1.8-9.9KB headroom. The unminified `train_gpt.py` (64KB) is included in the PR for readability. + +## Requirements + +- PyTorch 2.9.1+cu128 +- Flash Attention 3 (Hopper): `pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291` +- CUTLASS EVT extension (compiled for sm_90a, source included) +- SentencePiece, Brotli, NumPy +- 8×H100 80GB SXM + +```bash +SEED=1234 NUM_LOOPS=3 ENABLE_LOOPING_AT=0.35 PARALLEL_RESIDUAL_START=7 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Credits + +
Full component lineage: every piece traced to its origin PR + +| Component in this submission | Origin | Author | +|---|---|---| +| **This PR** | | | +| Triple depth recurrence (NUM_LOOPS=3) | This work | @abaybektursun | +| Earlier loop activation (enable_at=0.35) | This work | @abaybektursun | +| Triton TMA fused MLP forward | [#1105](https://github.com/openai/parameter-golf/pull/1105), ported to SP8192 here | @abaybektursun | +| CUTLASS EVT fused MLP backward | [#1105](https://github.com/openai/parameter-golf/pull/1105), ported to SP8192 here | @abaybektursun | +| Eval-time n-gram tilt (C++ open-addressing) | [#1105](https://github.com/openai/parameter-golf/pull/1105), re-tuned for SP8192 here | @abaybektursun | +| Double-buffered async data prefetch | This work | @abaybektursun | +| PyTorch Inductor bug fixes (2 upstream PRs) | [pytorch#179422](https://github.com/pytorch/pytorch/pull/179422), [pytorch#179494](https://github.com/pytorch/pytorch/pull/179494) | @abaybektursun | +| **Our prior submissions** | | | +| AR Self-Gen GPTQ + XSA-all + BigramHash (merged SOTA) | [#1019](https://github.com/openai/parameter-golf/pull/1019) | @abaybektursun | +| LeakyReLU² + Legal Score-First TTT + Parallel Muon | [#549](https://github.com/openai/parameter-golf/pull/549) | @abaybektursun | +| TTT negative results (why this submission does not use TTT) | [#756](https://github.com/openai/parameter-golf/pull/756), [#1103](https://github.com/openai/parameter-golf/pull/1103) | @abaybektursun | +| **Architecture** | | | +| SP8192 vocabulary | [#1394](https://github.com/openai/parameter-golf/pull/1394) | @clarkkev | +| SDClip quantization (c = k·σ) | [#1394](https://github.com/openai/parameter-golf/pull/1394) | @clarkkev | +| GPTQ on embeddings (int8) | [#1394](https://github.com/openai/parameter-golf/pull/1394) | @clarkkev | +| Tied embeddings (init_std=0.005) | [#1394](https://github.com/openai/parameter-golf/pull/1394) | @clarkkev | +| SP4096→8192 vocab scaling | [#1218](https://github.com/openai/parameter-golf/pull/1218) | @clarkkev | +| MLP 4.0× width, higher WD (0.085) | [#1218](https://github.com/openai/parameter-golf/pull/1218) | @clarkkev | +| Depth recurrence (loop layers 4-5) | [#1204](https://github.com/openai/parameter-golf/pull/1204) | @msisovic | +| Parallel residuals (GPT-J style) | [GPT-J](https://github.com/kingoflolz/mesh-transformer-jax) (2021), adapted in [#1204](https://github.com/openai/parameter-golf/pull/1204) | @kingoflolz, @msisovic | +| MuonEq-R (row-normalized Muon) | [#1217](https://github.com/openai/parameter-golf/pull/1217) | @bigbag | +| U-Net sigmoid-gated skip connections | [#289](https://github.com/openai/parameter-golf/pull/289), refined in [#1089](https://github.com/openai/parameter-golf/pull/1089) | @integrate-your-mind, @mikeapedia | +| XSA on all layers | [#265](https://github.com/openai/parameter-golf/pull/265) (partial), [#478](https://github.com/openai/parameter-golf/pull/478) (all layers) | @unnir, @gowtham0992 | +| Partial RoPE (16/64 dims) | [#315](https://github.com/openai/parameter-golf/pull/315) | @jfprincz | +| LN Scale (1/√(layer+1)) | [#315](https://github.com/openai/parameter-golf/pull/315) | @jfprincz | +| LeakyReLU(0.5)² activation | [#185](https://github.com/openai/parameter-golf/pull/185) | @dttdrv | +| Logit softcap (30.0) | [#315](https://github.com/openai/parameter-golf/pull/315) | @jfprincz | +| QK gain (4.0) | [#1125](https://github.com/openai/parameter-golf/pull/1125) | @jainpranjal97 | +| **Optimizer** | | | +| Muon (Newton-Schulz orthogonalization) | [#399](https://github.com/openai/parameter-golf/pull/399) (parallel variant) | @abaybektursun | +| EMA (decay=0.997) | [#315](https://github.com/openai/parameter-golf/pull/315), [#401](https://github.com/openai/parameter-golf/pull/401) | @jfprincz, @newjordan | +| Warmdown (0.667 frac, linear to 0) | [#364](https://github.com/openai/parameter-golf/pull/364) | @shikhar1729 | +| Muon momentum warmup (0.92→0.99) | [#1394](https://github.com/openai/parameter-golf/pull/1394) | @clarkkev | +| **Quantization & Compression** | | | +| Full Hessian GPTQ (actorder + Cholesky) | [#535](https://github.com/openai/parameter-golf/pull/535), integrated in [#1060](https://github.com/openai/parameter-golf/pull/1060) | @raahilshah, @dexhunter | +| Brotli-11 + byte shuffle compression | [#1089](https://github.com/openai/parameter-golf/pull/1089) | @mikeapedia | +| **Evaluation** | | | +| Sliding window (stride=64) | [#122](https://github.com/openai/parameter-golf/pull/122) | @mtybadger | +| Flash Attention 3 (Hopper) | [#122](https://github.com/openai/parameter-golf/pull/122) | @mtybadger | +| **Data** | | | +| ShuffledSequenceLoader (memmap + weighted sampling) | [#1394](https://github.com/openai/parameter-golf/pull/1394) | @clarkkev | + +This competition is deeply collaborative. Nearly every component traces through multiple contributors. I've tried to credit the earliest PR that introduced each technique, but many were refined across several submissions. + +
+ +--- + +## Appendix + +### A.0 Ablation: fused 5-seed without parallel residuals + +
5-seed results: fused kernels + triple loop + n-gram, no parallel residuals + +| Seed | Steps | Sliding BPB | N-gram BPB | Artifact | +|-|-|-|-|-| +| 1 | 4703 | 1.08336 | **1.08041** | 15,974,896 | +| 42 | 4704 | 1.08468 | **1.08175** | 15,974,993 | +| 1234 | 4680 | 1.08296 | **1.08007** | 15,971,965 | +| 1337 | 4697 | 1.08363 | **1.08077** | 15,970,370 | +| 2025 | 4702 | 1.08390 | **1.08101** | 15,970,844 | +| **Mean** | | **1.08371** | **1.08080** | | + +5-seed mean: **1.08080 BPB** (std=0.00064). Seed 1234 n-gram was run in terminal (1.08007), not logged to file. + +Adding parallel residuals (layers 7+) improves seed 1234 from 1.08007 to **1.07971** (-0.00036), primarily from +68 extra training steps due to the faster parallel forward pass. Full parallel-residuals 5-seed results are in the main table above (mean 1.08014). + +
+ +### A.1 Fused MLP Kernels: Design & Implementation + +These kernels were first developed for [PR #1105](https://github.com/openai/parameter-golf/pull/1105) on the SP4608 architecture. This submission ports them to the SP8192 + triple-loop architecture and integrates the CUTLASS EVT backward with `torch.compile`'s tracing. + +
Forward (Triton TMA): fuses F.linear + LeakyReLU(0.5) + square + +Fuses `F.linear(x, up_w) -> LeakyReLU(0.5) -> square` into a single kernel. The 403MB intermediate never touches HBM. + +Uses Triton's Tensor Memory Access (TMA) descriptors for H100-native global-to-shared memory loads. Block sizes `128x256x64` with 8 warps, 4 pipeline stages. The kernel performs the GEMM accumulation in FP32, then applies activation and squaring inline before writing back to BF16. + +The interleaved write pattern splits the accumulator into two halves via `tl.reshape + tl.permute + tl.split`, writing activation gradient and post-activation to separate output buffers in a single pass. + +
+ +
Backward (CUTLASS EVT): fuses (go @ down_w.T) * act_grad + +Fuses `(go @ down_w.T) * act_grad` into a single CUTLASS 3.x kernel via Epilogue Visitor Tree. The elementwise multiply runs in the GEMM epilogue while tiles are still in registers, eliminating one 403MB write + read per layer. + +I store the activation gradient in the forward pass instead of the pre-activation. This removes all branching from the backward: + +``` +act_grad = (pre > 0) ? 2*pre : 0.5*pre <-- one branch, forward only +post = 0.5 * act_grad * pre <-- branch-free recovery +dpre = (go @ W_down.T) * act_grad <-- branch-free backward +``` + +The identity `post = 0.5 * act_grad * pre` holds for both signs: +- pre > 0: act_grad = 2·pre → 0.5 · 2pre · pre = pre² ✓ +- pre ≤ 0: act_grad = 0.5·pre → 0.5 · 0.5pre · pre = (0.5·pre)² ✓ + +This reduces the CUTLASS EVT epilogue to a trivial 3-node tree: `Sm90EVT`. + +
+ +
Why surgical fusion, not full-MLP autograd.Function + +`torch.compile`'s cross-layer fusions (RMSNorm backward, residual adds, RoPE backward) account for ~21.6% of step time. Wrapping the full MLP backward in `autograd.Function` makes it opaque to Inductor, so everything runs in eager mode at 2.7x slower net (I hit this in my [#670](https://github.com/openai/parameter-golf/pull/670)). So I fuse only the forward activation and one backward GEMM+pointwise, preserving the compiler's scope over everything else. + +
+ +### A.2 Kernel Benchmarks + +
Per-layer timing and end-to-end + +| Variant | dpre time | Delta per layer | Delta per step (x11) | +|---|---|---|---| +| cuBLAS unfused | 1.221 ms | baseline | baseline | +| Triton precomp | 1.105 ms | -0.116 ms | -1.275 ms | +| CUTLASS Pingpong | 1.073 ms | -0.148 ms | -1.623 ms | + +End-to-end (35 steps, seed=42, 2xH100): + +| Config | Step avg | Delta | +|---|---|---| +| Triton fwd + Triton bwd | 313.90 ms | baseline | +| Triton fwd + CUTLASS EVT bwd | 313.47 ms | -0.43 ms | + +On 8xH100: unfused 4553 steps → fused 4680 steps in 588s (+127 steps, +2.8%). + +
+ +### A.3 Step-Time Profile + +
Where all 313ms goes (2xH100, Nsight Systems) + +| Component | Share | +|---|---| +| Flash Attention 3 (fwd+bwd) | 20.1% | +| Fused MLP (Triton+CUTLASS) | 13.5% | +| cuBLAS GEMMs (MLP bwd dW/dx, attn proj) | 19.1% | +| torch.compile fusions (cross-layer) | 21.6% | +| Unfused elementwise (LN, residuals) | 21.0% | +| Communication + other | 4.7% | + +
+ +### A.4 N-Gram Tilt + +The n-gram system was originally developed in [PR #1105](https://github.com/openai/parameter-golf/pull/1105) for SP4608 models. This submission ports it to SP8192. Source code: `ngram/fused_expert_blend.cpp` (C++ open-addressing hash, nanobind FFI) and `ngram/eval_ngram.py` (tilt math + sliding window). Eval time on 8xH100: ~90s. + +
Post-review causality fix + +The original submission had three hint channels: `token_hint` (orders 8–16), `within_hint` (within-word BPE completion), and `word_hint` (word-start prediction). @Gusanidas identified that `within_hint` and `word_hint` used `is_bnd`/`is_ws` flags derived from `tokens_[p]` (the target token) to gate whether a hint was produced — a Rule 1 violation. + +**What was invalid:** The gating decision "should I produce a hint at this position?" depended on whether the target token was a word boundary or had a leading space. This meant the probability distribution P(x_t | x_1...x_{t-1}) changed depending on the value of x_t itself. + +**What was tried to salvage within/word channels:** +- Deriving `is_bnd`/`is_ws` from `tokens_[p-1]` (prefix): semantically inverted, delta = +0.00033 (harmful) +- Gating on `within_len_` state only: fires too broadly, delta = +0.00120 (harmful) +- Disabling within/word entirely (token_hint only): delta = **-0.00014** (helpful) + +**Conclusion:** The within/word channels' -0.0025 BPB contribution came entirely from target-dependent gating. Without it, they add noise. Only `token_hint` (orders 8–16) produces a legitimate improvement. The fix removes within/word from hint output while keeping their state updates (dead code, no effect). + +**Parameter sweep (token_hint only, 4M token subset, 8 GPUs in parallel):** + +| base_beta | thresh_scale | table_bits | stride | delta | +|-----------|-------------|------------|--------|-------| +| **1.5** | **0.75** | **26** | **1** | **-0.000083** | +| 1.5 | 0.50 | 26 | 1 | -0.000081 | +| 2.0 | 0.75 | 26 | 1 | -0.000079 | +| 2.0 | 0.50 | 26 | 1 | -0.000074 | +| 1.0 | 1.00 | 26 | 1 | -0.000073 | +| 0.5 | 0.50 | 26 | 1 | -0.000046 | +| 3.0 | 0.50 | 26 | 1 | -0.000020 | +| 5.0 | 0.50 | 26 | 1 | +0.000214 | + +Full-val delta with best params (beta=1.5): consistent **-0.00014 BPB** across all 5 seeds. The improvement is real but small. + +
+ +
Causality proof (token_hint channel) + +The surviving `token_hint` channel is a textbook online n-gram with strict lookup-then-update discipline: + +```cpp +for (int i = 0; i < n; i++) { + int64_t p = pos[i]; + compute_hashes(tokens_, p, ...); // (1) hash from tokens[p-1], tokens[p-2], ... + token_hint(hashes, ..., tok_hint, ...); // (2) LOOKUP in tables built from pos < p + hints[i] = tok_hint; // (3) emit hint + token_update(hashes, ..., tok); // (4) INSERT tokens[p] AFTER hint is emitted +} +``` + +| Condition | Requirement | Status | +|---|---|---| +| Causal dependence | `p_t` depends only on artifact + `x_1...x_{t-1}` | PASS | +| Full normalized distribution | Proper softmax over full vocab | PASS | +| Score-before-update | Score fixed before any `x_t`-dependent update | PASS | +| Single left-to-right pass | No rescoring | PASS | + +
+ +### A.5 Data Prefetch + +
Double-buffered async prefetch + +Background thread prepares next batch in pinned memory while GPU trains. Separate CUDA stream for H2D overlap. + +On the PR #1334 architecture: +39 steps, +0.7% throughput. The extra steps landed in a worse compression region (+40KB), so the net effect was actually harmful for that architecture. On PR #1394's `ShuffledSequenceLoader` with memmap, the data pipeline is already efficient enough that prefetch isn't the bottleneck. + +
+ +### A.6 ETLB (Eval-Time Logit Bias) + +
Algorithm and results + +From [PR #1399](https://github.com/openai/parameter-golf/pull/1399). Learns a vocab-sized bias vector via SGD on already-scored context tokens, carried across sliding windows: + +1. Forward pass (no grad) → logits +2. 5 SGD steps (lr=0.05) on context tokens (first 1984 of 2048) +3. Score stride tokens (last 64) with `logits + bias` +4. Carry bias forward, clamped to [-3, 3] + +Result (seed 1234, double-loop config on torch 2.11): n-gram only 1.08152 → ETLB + n-gram 1.08132 (-0.00020). Not re-tested on the final triple-loop fused config. + +Rejected: takes 615s, doesn't fit in 600s eval budget. + +
+ +### A.7 Setup & Reproduction + +
Full build instructions + +```bash +pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128 +pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291 +pip install sentencepiece brotli numpy + +export LD_LIBRARY_PATH=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')"):${LD_LIBRARY_PATH:-} + +cd /opt && git clone --depth 1 --branch v3.7.0 https://github.com/NVIDIA/cutlass +cd cutlass_evt_fusion && CUTLASS_PATH=/opt/cutlass python3 setup.py build_ext --inplace && cd .. + +rm -f data/manifest.json +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf \ +python3 data/cached_challenge_fineweb.py --variant sp8192 --train-shards 128 + +SEED=1234 NUM_LOOPS=3 ENABLE_LOOPING_AT=0.35 PARALLEL_RESIDUAL_START=7 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +
diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/csrc/gemm_act_grad.cu b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/csrc/gemm_act_grad.cu new file mode 100644 index 0000000000..aa67016fc9 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/csrc/gemm_act_grad.cu @@ -0,0 +1,178 @@ +// CUTLASS 3.x EVT kernel: fused GEMM * elementwise multiply +// Computes: dpre = (go @ down_w.T) * act_grad +// Where act_grad = f'(pre) is pre-computed in the forward pass. +// +// Layout convention: +// go: (M, K) bf16 row-major +// down_w: (K, N) bf16 row-major — CUTLASS B(N,K) with RowMajor layout +// act_grad: (M, N) bf16 row-major +// dpre: (M, N) bf16 row-major output + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cute/tensor.hpp" +#include "cutlass/util/packed_stride.hpp" +#include + +using namespace cute; + +// --- Type aliases --- + +using ElementAcc = float; +using ElementCompute = float; +using ElementOutput = cutlass::bfloat16_t; +using ElementAux = cutlass::bfloat16_t; + +using namespace cutlass::epilogue::fusion; + +// --- Tile / schedule configuration --- + +using TileShape = Shape<_128, _256, _64>; +using ClusterShape = Shape<_1, _1, _1>; +using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + +// --- Resolve AuxLoad types via EpilogueDescriptor --- + +using EpiDesc = cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, EpilogueTile, ElementOutput, ElementOutput, EpilogueSchedule>; + +using AuxDesc = cutlass::epilogue::collective::detail::AuxLoadDescriptor< + EpiDesc, cutlass::layout::RowMajor, ElementAux>; + +// --- EVT tree: acc * aux_load (builtin multiply) --- + +using AuxLoad = Sm90AuxLoad< + AuxDesc::Stages, + typename EpiDesc::EpilogueTile, + typename AuxDesc::Element, + typename AuxDesc::Stride, + typename AuxDesc::SmemLayoutAtom, + typename AuxDesc::CopyOpS2R>; + +// Compute node: builtin multiply(acc, act_grad) +using Compute = Sm90Compute< + cutlass::multiplies, + ElementOutput, + ElementCompute, + cutlass::FloatRoundStyle::round_to_nearest>; + +// Tree: root = Multiply(child0 = AccFetch, child1 = AuxLoad) +using EVT = Sm90EVT; + +// --- CollectiveBuilder + Kernel type --- + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + EpilogueTile, + ElementAcc, ElementCompute, + ElementOutput, cutlass::layout::RowMajor, /* AlignC */ 8, + ElementOutput, cutlass::layout::RowMajor, /* AlignD */ 8, + EpilogueSchedule, + EVT +>::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + ElementOutput, cutlass::layout::RowMajor, /* AlignA */ 8, + ElementOutput, cutlass::layout::RowMajor, /* AlignB */ 8, + ElementAcc, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + sizeof(typename CollectiveEpilogue::SharedStorage)>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative +>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + +using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + +// --- Host launcher --- + +void launch_gemm_mul( + void const* ptr_go, // (M, K) bf16 row-major + void const* ptr_down_w, // (K, N) bf16 row-major = RowMajor B(N,K) for CUTLASS + void const* ptr_act_grad, // (M, N) bf16 row-major + void* ptr_dpre, // (M, N) bf16 row-major output + int M, int N, int K, + cudaStream_t stream) +{ + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + + int L = 1; + auto prob_shape = make_shape(M, N, K, L); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_Aux = cutlass::make_cute_packed_stride( + typename AuxDesc::Stride{}, cute::make_shape(M, N, L)); + + typename EVT::Arguments evt_args { + {}, // Sm90AccFetch: no args + { // Sm90AuxLoad: pointer + null_default + stride + static_cast(ptr_act_grad), + ElementAux(0), + stride_Aux + }, + {} // Sm90Compute (multiplies): no args + }; + + typename GemmOp::Arguments args { + cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, + { // Mainloop + static_cast(ptr_go), + stride_A, + static_cast(ptr_down_w), + stride_B, + }, + { // Epilogue: {thread_args, ptr_C, stride_C, ptr_D, stride_D} + evt_args, + static_cast(ptr_dpre), // ptr_C (unused but TMA needs valid ptr) + stride_C, + static_cast(ptr_dpre), // ptr_D (output) + stride_C, + } + }; + + GemmOp gemm_op; + size_t workspace_size = GemmOp::get_workspace_size(args); + void* workspace = nullptr; + if (workspace_size > 0) { + cudaMalloc(&workspace, workspace_size); + } + + auto status = gemm_op.initialize(args, workspace, stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS initialize failed: " << cutlassGetStatusString(status) << std::endl; + if (workspace) cudaFree(workspace); + exit(EXIT_FAILURE); + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + cudaError_t cuda_err = cudaStreamSynchronize(stream); + std::cerr << "CUTLASS run failed: " << cutlassGetStatusString(status) + << " CUDA: " << cudaGetErrorString(cuda_err) << std::endl; + if (workspace) cudaFree(workspace); + exit(EXIT_FAILURE); + } + + if (workspace) cudaFree(workspace); +} diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/csrc/torch_binding.cpp b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/csrc/torch_binding.cpp new file mode 100644 index 0000000000..40c6d5dd49 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/csrc/torch_binding.cpp @@ -0,0 +1,46 @@ +// PyTorch C++ extension: CUTLASS EVT fused GEMM * elementwise multiply +// dpre = (go @ down_w.T) * act_grad +// Pass down_w directly (K, N) — NOT down_w.T.contiguous() + +#include +#include + +void launch_gemm_mul( + void const*, void const*, void const*, void*, int, int, int, cudaStream_t); + +at::Tensor gemm_mul(at::Tensor go, at::Tensor down_w, at::Tensor act_grad) { + TORCH_CHECK(go.is_cuda() && go.is_contiguous()); + TORCH_CHECK(down_w.is_cuda() && down_w.is_contiguous()); + TORCH_CHECK(act_grad.is_cuda() && act_grad.is_contiguous()); + TORCH_CHECK(go.scalar_type() == at::kBFloat16); + TORCH_CHECK(down_w.scalar_type() == at::kBFloat16); + TORCH_CHECK(act_grad.scalar_type() == at::kBFloat16); + + int M = go.size(0); + int K = go.size(1); + int N = down_w.size(1); // down_w is (K, N) row-major + + TORCH_CHECK(down_w.size(0) == K, + "K mismatch: go has K=", K, " but down_w has size(0)=", down_w.size(0)); + TORCH_CHECK(act_grad.size(0) == M && act_grad.size(1) == N, + "act_grad shape must be (M, N), got (", act_grad.size(0), ", ", act_grad.size(1), ")"); + + at::Tensor dpre = at::empty({M, N}, go.options()); + + launch_gemm_mul( + go.data_ptr(), down_w.data_ptr(), act_grad.data_ptr(), dpre.data_ptr(), + M, N, K, + at::cuda::getCurrentCUDAStream()); + + return dpre; +} + +TORCH_LIBRARY(cutlass_evt, m) { + m.def("gemm_mul(Tensor go, Tensor down_w, Tensor act_grad) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(cutlass_evt, CUDA, m) { + m.impl("gemm_mul", &gemm_mul); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/setup.py b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/setup.py new file mode 100644 index 0000000000..ec282243bd --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/cutlass_evt_fusion/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + +CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/opt/cutlass") + +setup( + name="cutlass_evt_fusion", + ext_modules=[ + CUDAExtension( + name="cutlass_evt_fusion", + sources=[ + "csrc/gemm_act_grad.cu", + "csrc/torch_binding.cpp", + ], + include_dirs=[ + f"{CUTLASS_PATH}/include", + f"{CUTLASS_PATH}/tools/util/include", + ], + extra_compile_args={ + "nvcc": [ + "-std=c++17", + "-arch=sm_90a", + "-O3", + "--use_fast_math", + "--expt-relaxed-constexpr", + "-DNDEBUG", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + ], + }, + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/eval_ngram.py b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/eval_ngram.py new file mode 100644 index 0000000000..8a1821e94a --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/eval_ngram.py @@ -0,0 +1,261 @@ +"""Eval-only: run sliding window + n-gram tilt on an existing quantized model. +Usage: torchrun --standalone --nproc_per_node=8 eval_ngram.py --model final_model.int6.ptz +""" +import argparse, glob, io, math, os, time +from pathlib import Path +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F + +def load_data_shard(file): + header = np.fromfile(file, dtype="= 0).to(torch.float64) + scored_logits = logits_f[i, s:wlen] + tgt = y[i, s:wlen] + safe_h = hint.clamp(min=0) + logit_tgt = scored_logits.gather(-1, tgt.unsqueeze(-1)).squeeze(-1).to(torch.float64) + logit_hint = scored_logits.gather(-1, safe_h.unsqueeze(-1)).squeeze(-1).to(torch.float64) + lse = scored_nll + logit_tgt + p_hint = (logit_hint - lse).exp().clamp(0.0, 1.0) + Z = 1.0 + p_hint * (beta.exp() - 1.0) + is_hit = (tgt == hint).to(torch.float64) + mixed_nll = scored_nll + has_hint * (Z.log() - beta * is_hit) + tilt_loss += mixed_nll.sum() + tc += float(wlen - s) + prev = x[i, s:wlen] + tb = bb_lut[tgt].to(torch.float64) + tb += (ls_lut[tgt] & ~bd_lut[prev]).to(torch.float64) + bc += tb.sum() + + if distributed: + for t in (base_loss, tilt_loss, tc, bc): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + + elapsed = time.perf_counter() - t0 + tpb = tc.item() / bc.item() + base_bpb = (base_loss.item() / tc.item() / math.log(2)) * tpb + tilt_bpb = (tilt_loss.item() / tc.item() / math.log(2)) * tpb + + if master: + print(f"\nbase_sw_bpb: {base_bpb:.8f}") + print(f"ngram_tilt_bpb: {tilt_bpb:.8f}") + print(f"delta: {tilt_bpb - base_bpb:+.8f}") + print(f"eval_time: {elapsed:.1f}s") + + if distributed: + dist.barrier() + dist.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/fused_expert_blend.cpp b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/fused_expert_blend.cpp new file mode 100644 index 0000000000..f2b7331498 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/fused_expert_blend.cpp @@ -0,0 +1,449 @@ +#include +#include + +#include +#include +#include + +#ifdef __linux__ +#include +#endif + +namespace nb = nanobind; + +static constexpr uint64_t PRIMES[] = { + 36313ULL, 27191ULL, 51647ULL, 81929ULL, 131071ULL, 196613ULL, + 262147ULL, 393241ULL, 524309ULL, 655373ULL, 786433ULL, 917521ULL, + 1048583ULL, 1179653ULL, 1310729ULL, 1441801ULL, 1572869ULL, 1703941ULL, + 1835017ULL, 1966087ULL, 2097169ULL, 2228243ULL, 2359319ULL, 2490389ULL, + 2621471ULL, 2752549ULL, 2883617ULL, 3014687ULL, 3145757ULL, 3276833ULL, + 3407903ULL, 3538973ULL, +}; +static constexpr int N_PRIMES = 32; +static constexpr uint64_t PAIR_MIX = 1000003ULL; +static constexpr uint64_t PREFIX_BASE = 1099511628211ULL; +static constexpr uint64_t LEN_MIX = 0x9E3779B185EBCA87ULL; +static constexpr uint64_t TABLE_MIX = 0x9e3779b97f4a7c15ULL; +static constexpr uint64_t EMPTY_KEY = 0xFFFFFFFFFFFFFFFFULL; + +struct CtxEntry { + uint64_t key; + uint32_t count; + uint16_t best_tok; + uint16_t best_count; +}; + +struct PairEntry { + uint64_t key; + uint32_t count; + uint32_t _pad; +}; + +struct OpenTable { + uint32_t mask; + static constexpr int MAX_PROBES = 16; + + CtxEntry* ctx = nullptr; + PairEntry* pair = nullptr; + size_t cap = 0; + + ~OpenTable() { free_tables(); } + + void free_tables() { +#ifdef __linux__ + if (ctx) { munmap(ctx, cap * sizeof(CtxEntry)); ctx = nullptr; } + if (pair) { munmap(pair, cap * sizeof(PairEntry)); pair = nullptr; } +#else + delete[] ctx; ctx = nullptr; + delete[] pair; pair = nullptr; +#endif + } + + void init(int bits) { + free_tables(); + cap = size_t(1) << bits; + mask = uint32_t(cap - 1); +#ifdef __linux__ + ctx = (CtxEntry*)mmap(nullptr, cap * sizeof(CtxEntry), + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE, -1, 0); + pair = (PairEntry*)mmap(nullptr, cap * sizeof(PairEntry), + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE, -1, 0); +#else + ctx = new CtxEntry[cap]; + pair = new PairEntry[cap]; +#endif + clear(); + } + + void clear() { + for (size_t i = 0; i < cap; i++) ctx[i] = {EMPTY_KEY, 0, 0, 0}; + for (size_t i = 0; i < cap; i++) pair[i] = {EMPTY_KEY, 0, 0}; + } + + void reset() { clear(); } + + void prefetch_ctx(uint64_t key) const { + uint32_t slot = uint32_t((key * TABLE_MIX) & mask); + __builtin_prefetch(&ctx[slot], 0, 0); + } + void prefetch_update(uint64_t ctx_key, uint64_t pair_key) const { + __builtin_prefetch(&ctx[uint32_t((ctx_key * TABLE_MIX) & mask)], 1, 0); + __builtin_prefetch(&pair[uint32_t((pair_key * TABLE_MIX) & mask)], 1, 0); + } + + void ctx_lookup(uint64_t key, int& out_tok, double& out_conf, + uint32_t& out_count) const { + uint32_t slot = uint32_t((key * TABLE_MIX) & mask); + for (int p = 0; p < MAX_PROBES; p++) { + uint32_t s = (slot + p) & mask; + if (ctx[s].key == key) { + out_count = ctx[s].count; + out_tok = ctx[s].best_tok; + out_conf = double(ctx[s].best_count) / double(out_count); + return; + } + if (ctx[s].key == EMPTY_KEY) break; + } + out_tok = -1; out_conf = 0.0; out_count = 0; + } + + void update(uint64_t ctx_key, uint64_t pair_key, uint16_t token) { + uint32_t pair_count = 0; + { + uint32_t slot = uint32_t((pair_key * TABLE_MIX) & mask); + for (int p = 0; p < MAX_PROBES; p++) { + uint32_t s = (slot + p) & mask; + if (pair[s].key == pair_key) { + pair[s].count++; pair_count = pair[s].count; break; + } + if (pair[s].key == EMPTY_KEY) { + pair[s].key = pair_key; pair[s].count = 1; + pair_count = 1; break; + } + } + } + { + uint32_t slot = uint32_t((ctx_key * TABLE_MIX) & mask); + for (int p = 0; p < MAX_PROBES; p++) { + uint32_t s = (slot + p) & mask; + if (ctx[s].key == ctx_key) { + ctx[s].count++; + if (token == ctx[s].best_tok) ctx[s].best_count++; + else if (pair_count > ctx[s].best_count) { + ctx[s].best_tok = token; + ctx[s].best_count = uint16_t(std::min(pair_count, 65535u)); + } + return; + } + if (ctx[s].key == EMPTY_KEY) { + ctx[s] = {ctx_key, 1, token, 1}; return; + } + } + } + } +}; + +class ContextMixer { + static constexpr int OPEN_MIN = 8; + static constexpr int OPEN_MAX = 16; + static constexpr int N_OPEN = OPEN_MAX - OPEN_MIN + 1; + + OpenTable open_[N_OPEN]; + + struct OrderConfig { double threshold; uint32_t min_count; }; + OrderConfig cfg_[N_OPEN]; + + bool order_active_[N_OPEN]; + int order_stride_; + + static constexpr int WITHIN_ORDERS = 3; + OpenTable within_[WITHIN_ORDERS]; + uint64_t within_hash_; + uint32_t within_len_; + double within_threshold_, within_beta_; + + static constexpr int WORD_ORDER = 4; + OpenTable word_table_; + uint64_t word_ring_[4]; + int word_ring_head_, word_ring_fill_; + uint64_t current_word_hash_; + int current_word_len_; + double word_threshold_, word_beta_; + + double base_beta_, agree_bonus_; + + const int64_t* tokens_ = nullptr; + int64_t n_tokens_ = 0; + const int16_t* base_bytes_ = nullptr; + const uint8_t* has_ls_ = nullptr; + const uint8_t* is_bnd_ = nullptr; + + static void compute_hashes(const int64_t* tokens, int64_t pos, int max_ord, + uint64_t* hashes) { + uint64_t h = 0; + int lim = std::min(max_ord, int(pos)); + for (int k = 0; k < lim; k++) { + h ^= uint64_t(tokens[pos - k - 1]) * PRIMES[k % N_PRIMES]; + hashes[k] = h; + } + for (int k = lim; k < max_ord; k++) hashes[k] = 0; + } + + static uint64_t pair_key(uint64_t ctx, uint16_t tok, int order) { + return (ctx * PAIR_MIX) ^ (uint64_t(tok) * PRIMES[order % N_PRIMES]); + } + + static uint64_t extend_prefix(uint64_t h, uint16_t tok, uint32_t pos) { + return (h * PREFIX_BASE) ^ ((uint64_t(tok) + 1) * PRIMES[pos % N_PRIMES]); + } + + void token_hint(const uint64_t* hashes, int max_avail, + int& out_tok, double& out_beta) { + for (int order = std::min(OPEN_MAX, max_avail); order >= OPEN_MIN; order--) { + int oi = order - OPEN_MIN; + if (!order_active_[oi]) continue; + uint64_t ch = hashes[order - 1]; + int hint; double conf; uint32_t count; + open_[oi].ctx_lookup(ch, hint, conf, count); + if (hint >= 0 && conf >= cfg_[oi].threshold + && count >= cfg_[oi].min_count) { + out_tok = hint; + out_beta = base_beta_ * conf; + return; + } + } + out_tok = -1; out_beta = 0.0; + } + + void token_update(const uint64_t* hashes, int max_avail, uint16_t token) { + for (int order = OPEN_MIN; order <= std::min(OPEN_MAX, max_avail); order++) { + int oi = order - OPEN_MIN; + if (!order_active_[oi]) continue; + uint64_t ch = hashes[order - 1]; + uint64_t pk = pair_key(ch, token, order); + open_[oi].update(ch, pk, token); + } + } + + void within_hint(bool is_bnd, bool is_ws, int& out_tok, double& out_beta) { + if (is_bnd || is_ws || within_len_ == 0) { + out_tok = -1; out_beta = 0.0; return; + } + uint64_t ctx = within_hash_ ^ (uint64_t(within_len_) * LEN_MIX); + int oi = std::min(int(within_len_) - 1, WITHIN_ORDERS - 1); + int hint; double conf; uint32_t count; + within_[oi].ctx_lookup(ctx, hint, conf, count); + if (hint >= 0 && conf >= within_threshold_ && count >= 1) { + out_tok = hint; out_beta = within_beta_; + } else { + out_tok = -1; out_beta = 0.0; + } + } + + void within_update(uint16_t token, bool is_bnd, bool is_ws) { + if (is_bnd) { within_hash_ = 0; within_len_ = 0; return; } + if (is_ws || within_len_ == 0) { + within_hash_ = extend_prefix(0, token, 0); + within_len_ = 1; return; + } + uint64_t ctx = within_hash_ ^ (uint64_t(within_len_) * LEN_MIX); + uint64_t pk = (ctx * PAIR_MIX) ^ (uint64_t(token) * PRIMES[0]); + int oi = std::min(int(within_len_) - 1, WITHIN_ORDERS - 1); + within_[oi].update(ctx, pk, token); + within_hash_ = extend_prefix(within_hash_, token, within_len_); + within_len_++; + } + + uint64_t word_ctx_hash() const { + uint64_t h = 0; + int n = std::min(word_ring_fill_, WORD_ORDER); + for (int j = 0; j < n; j++) { + int idx = (word_ring_head_ - n + j + WORD_ORDER) % WORD_ORDER; + h ^= word_ring_[idx] * PRIMES[j % N_PRIMES]; + } + return h; + } + + void word_hint(bool is_ws, int& out_tok, double& out_beta) { + if (!is_ws || word_ring_fill_ < WORD_ORDER) { + out_tok = -1; out_beta = 0.0; return; + } + uint64_t ctx = word_ctx_hash(); + int hint; double conf; uint32_t count; + word_table_.ctx_lookup(ctx, hint, conf, count); + if (hint >= 0 && conf >= word_threshold_ && count >= 3) { + out_tok = hint; out_beta = word_beta_; + } else { + out_tok = -1; out_beta = 0.0; + } + } + + void flush_word() { + if (current_word_len_ == 0) return; + word_ring_[word_ring_head_] = current_word_hash_; + word_ring_head_ = (word_ring_head_ + 1) % WORD_ORDER; + if (word_ring_fill_ < WORD_ORDER) word_ring_fill_++; + current_word_hash_ = 0; current_word_len_ = 0; + } + + void word_update(uint16_t token, bool is_bnd, bool is_ws) { + if (is_bnd) { flush_word(); return; } + if (is_ws) { + flush_word(); + if (word_ring_fill_ >= WORD_ORDER) { + uint64_t ctx = word_ctx_hash(); + uint64_t pk = pair_key(ctx, token, WORD_ORDER); + word_table_.update(ctx, pk, token); + } + } + current_word_hash_ = current_word_hash_ * 31 + token; + current_word_len_++; + } + + void prefetch_open_lookups(const uint64_t* hashes, int max_avail) const { + for (int order = std::min(OPEN_MAX, max_avail); order >= OPEN_MIN; order--) { + int oi = order - OPEN_MIN; + if (!order_active_[oi]) continue; + open_[oi].prefetch_ctx(hashes[order - 1]); + } + } + + void prefetch_open_updates(const uint64_t* hashes, int max_avail, uint16_t token) const { + for (int order = OPEN_MIN; order <= std::min(OPEN_MAX, max_avail); order++) { + int oi = order - OPEN_MIN; + if (!order_active_[oi]) continue; + uint64_t ch = hashes[order - 1]; + uint64_t pk = pair_key(ch, token, order); + open_[oi].prefetch_update(ch, pk); + } + } + +public: + ContextMixer(double base_beta = 1.0, double agree_bonus = 0.5, + double within_threshold = 0.80, double within_beta = 0.75, + double word_threshold = 0.80, double word_beta = 0.50, + int open_table_bits = 22, double token_threshold_scale = 1.0, + int order_stride = 1) + : within_hash_(0), within_len_(0), + within_threshold_(within_threshold), within_beta_(within_beta), + word_ring_head_(0), word_ring_fill_(0), + current_word_hash_(0), current_word_len_(0), + word_threshold_(word_threshold), word_beta_(word_beta), + base_beta_(base_beta), agree_bonus_(agree_bonus), + order_stride_(order_stride) { + + std::memset(word_ring_, 0, sizeof(word_ring_)); + + for (int i = 0; i < N_OPEN; i++) { + int order = OPEN_MIN + i; + order_active_[i] = ((order - OPEN_MIN) % order_stride == 0); + if (order_active_[i]) + open_[i].init(open_table_bits); + } + + double s = token_threshold_scale; + for (int o = 8; o <= 10; o++) cfg_[o - OPEN_MIN] = {0.70 * s, 3}; + for (int o = 11; o <= 13; o++) cfg_[o - OPEN_MIN] = {0.60 * s, 2}; + for (int o = 14; o <= 16; o++) cfg_[o - OPEN_MIN] = {0.50 * s, 2}; + + for (int i = 0; i < WITHIN_ORDERS; i++) + within_[i].init(20); + + word_table_.init(20); + } + + void set_tokens(nb::ndarray, nb::c_contig, nb::device::cpu> t) { + tokens_ = t.data(); n_tokens_ = int64_t(t.shape(0)); + } + + void set_luts( + nb::ndarray, nb::c_contig, nb::device::cpu> bb, + nb::ndarray, nb::c_contig, nb::device::cpu> ls, + nb::ndarray, nb::c_contig, nb::device::cpu> bd) { + base_bytes_ = bb.data(); has_ls_ = ls.data(); is_bnd_ = bd.data(); + } + + void reset() { + for (auto& o : open_) if (o.ctx) o.reset(); + for (auto& w : within_) w.reset(); + word_table_.reset(); + within_hash_ = 0; within_len_ = 0; + word_ring_head_ = 0; word_ring_fill_ = 0; + current_word_hash_ = 0; current_word_len_ = 0; + } + + void get_hints_batch( + nb::ndarray, nb::c_contig, nb::device::cpu> positions, + nb::ndarray, nb::c_contig, nb::device::cpu> out_hints, + nb::ndarray, nb::c_contig, nb::device::cpu> out_betas) { + + const int n = int(positions.shape(0)); + const int64_t* pos = positions.data(); + int32_t* hints = out_hints.data(); + double* betas = out_betas.data(); + + uint64_t hashes[OPEN_MAX]; + uint64_t next_hashes[OPEN_MAX]; + + if (n > 0) { + int64_t p0 = pos[0]; + compute_hashes(tokens_, p0, OPEN_MAX, hashes); + int ma0 = std::min(OPEN_MAX, int(p0)); + prefetch_open_lookups(hashes, ma0); + } + + for (int i = 0; i < n; i++) { + int64_t p = pos[i]; + auto tok = uint16_t(tokens_[p]); + auto prev_tok = (p > 0) ? uint16_t(tokens_[p - 1]) : uint16_t(0); + bool is_bnd = is_bnd_ && is_bnd_[prev_tok]; + bool is_ws = has_ls_ && has_ls_[prev_tok]; + int max_avail = std::min(OPEN_MAX, int(p)); + + if (i + 1 < n) { + int64_t np = pos[i + 1]; + compute_hashes(tokens_, np, OPEN_MAX, next_hashes); + int nma = std::min(OPEN_MAX, int(np)); + prefetch_open_lookups(next_hashes, nma); + } + + int tok_hint; + double tok_beta; + token_hint(hashes, max_avail, tok_hint, tok_beta); + + hints[i] = tok_hint; + betas[i] = tok_beta; + + prefetch_open_updates(hashes, max_avail, tok); + + bool tok_is_bnd = is_bnd_ && is_bnd_[tok]; + bool tok_is_ws = has_ls_ && has_ls_[tok]; + token_update(hashes, max_avail, tok); + within_update(tok, tok_is_bnd, tok_is_ws); + word_update(tok, tok_is_bnd, tok_is_ws); + + std::memcpy(hashes, next_hashes, sizeof(hashes)); + } + } + +}; + +NB_MODULE(fused_expert_ext, m) { + m.doc() = "N-gram hint generator with open-addressing (orders 8-16 + within-word + word-start)"; + + nb::class_(m, "ContextMixer") + .def(nb::init(), + nb::arg("base_beta") = 1.0, nb::arg("agree_bonus") = 0.5, + nb::arg("within_threshold") = 0.80, nb::arg("within_beta") = 0.75, + nb::arg("word_threshold") = 0.80, nb::arg("word_beta") = 0.50, + nb::arg("open_table_bits") = 22, nb::arg("token_threshold_scale") = 1.0, + nb::arg("order_stride") = 1) + .def("set_tokens", &ContextMixer::set_tokens, nb::arg("tokens")) + .def("set_luts", &ContextMixer::set_luts, + nb::arg("base_bytes"), nb::arg("has_leading_space"), nb::arg("is_boundary")) + .def("reset", &ContextMixer::reset) + .def("get_hints_batch", &ContextMixer::get_hints_batch, + nb::arg("positions"), nb::arg("out_hints"), nb::arg("out_betas")) +} diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/ngram_blend.cpp b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/ngram_blend.cpp new file mode 100644 index 0000000000..bff5d5ef5b --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/ngram/ngram_blend.cpp @@ -0,0 +1,297 @@ +/* + * fast_ngram_ext — C++ accelerated n-gram lookup, blend, and causal update. + * + * Replaces the pure Python/NumPy inner loop in eval_ngram.py. + * Same algorithm, same hash functions, same data structures (flat arrays). + */ + +#include +#include + +#include +#include +#include +#include + +namespace nb = nanobind; + +static constexpr uint64_t PRIMES[9] = { + 36313ULL, 27191ULL, 51647ULL, 81929ULL, 131071ULL, + 174763ULL, 233017ULL, 310019ULL, 412553ULL, +}; + +class NGramBlender { + int min_order_, max_order_, n_orders_; + int min_count_, ngram_buckets_; + uint64_t bucket_mask_; + + // Alpha config + int alpha_mode_; // 0=fixed, 1=entropy, 2=order_entropy + double fixed_alpha_; + double ent_base_, ent_range_, ent_scale_, ent_thresh_; + double order_ent_center_, order_ent_slope_; + + // Mixing function: 0=linear, 1=logistic, 2=geometric + int mixing_fn_; + + // Count tables: flat arrays indexed by (hash & mask) + std::vector> ctx_tables_; + std::vector> full_tables_; + + // Borrowed pointer to full token array + const int64_t* tokens_ = nullptr; + int64_t n_tokens_ = 0; + + // --- Hash functions (match Python exactly) --- + + inline uint64_t hash_ctx(int64_t pos, int ctx_w) const { + uint64_t h = 0; + for (int k = 0; k < ctx_w; k++) { + auto tok = static_cast(tokens_[pos - (ctx_w - k)]); + h ^= tok * PRIMES[k % 9]; + } + return h; + } + + inline uint64_t hash_with_target(uint64_t ctx_h, uint64_t target, + int ctx_w) const { + return ctx_h ^ (target * PRIMES[ctx_w % 9]); + } + + // --- Core stride processing on raw pointers --- + + void process_stride_impl(const int64_t* pos, int n_pos, + const double* nll, const double* ent, + double* out) { + // Per-position best n-gram match + std::vector best_p(n_pos, -1.0); + std::vector best_ord(n_pos, 0); + + // --- LOOKUP: highest order first (backoff) --- + for (int oi = n_orders_ - 1; oi >= 0; oi--) { + int order = min_order_ + oi; + int ctx_w = order - 1; + + for (int p = 0; p < n_pos; p++) { + if (best_p[p] >= 0.0) continue; + if (pos[p] < order) continue; + + uint64_t ch = hash_ctx(pos[p], ctx_w); + uint64_t ck = ch & bucket_mask_; + auto tgt = static_cast(tokens_[pos[p]]); + uint64_t fk = hash_with_target(ch, tgt, ctx_w) & bucket_mask_; + + auto cc = static_cast(ctx_tables_[oi][ck]); + auto fc = static_cast(full_tables_[oi][fk]); + + if (cc >= static_cast(min_count_)) { + double pn = std::min(fc, cc) / std::max(cc, 1.0); + best_p[p] = std::clamp(pn, 0.0, 1.0); + best_ord[p] = order; + } + } + } + + // --- MIX --- + for (int p = 0; p < n_pos; p++) { + if (best_p[p] < 0.0) { + out[p] = nll[p]; + continue; + } + + double alpha; + if (alpha_mode_ == 2 && ent) { + double mo = static_cast(best_ord[p]); + double center = + order_ent_center_ - order_ent_slope_ * (mo - min_order_); + double sig = + 1.0 / (1.0 + std::exp(-ent_scale_ * (ent[p] - center))); + alpha = ent_base_ + ent_range_ * sig; + } else if (alpha_mode_ == 1 && ent) { + double sig = 1.0 / (1.0 + std::exp(-ent_scale_ * + (ent[p] - ent_thresh_))); + alpha = ent_base_ + ent_range_ * sig; + } else { + alpha = fixed_alpha_; + } + + double mp = std::exp(-nll[p]); + double mixed; + + if (mixing_fn_ == 0) { + mixed = (1.0 - alpha) * mp + alpha * best_p[p]; + } else if (mixing_fn_ == 1) { + constexpr double eps = 1e-7; + double pm = std::clamp(mp, eps, 1.0 - eps); + double pn_c = std::clamp(best_p[p], eps, 1.0 - eps); + double lm = std::log(pm / (1.0 - pm)); + double ln = std::log(pn_c / (1.0 - pn_c)); + double combined = (1.0 - alpha) * lm + alpha * ln; + mixed = 1.0 / (1.0 + std::exp(-combined)); + } else { + constexpr double eps = 1e-12; + double log_mix = + (1.0 - alpha) * std::log(std::max(mp, eps)) + + alpha * std::log(std::max(best_p[p], eps)); + mixed = std::exp(log_mix); + } + + out[p] = -std::log(std::max(mixed, 1e-12)); + } + + // --- UPDATE (after scoring — strict causality) --- + for (int oi = 0; oi < n_orders_; oi++) { + int order = min_order_ + oi; + int ctx_w = order - 1; + + for (int p = 0; p < n_pos; p++) { + if (pos[p] < order) continue; + + uint64_t ch = hash_ctx(pos[p], ctx_w); + uint64_t ck = ch & bucket_mask_; + auto tgt = static_cast(tokens_[pos[p]]); + uint64_t fk = + hash_with_target(ch, tgt, ctx_w) & bucket_mask_; + + ctx_tables_[oi][ck]++; + full_tables_[oi][fk]++; + } + } + } + + public: + NGramBlender(int min_order, int max_order, int ngram_buckets, int min_count) + : min_order_(min_order), + max_order_(max_order), + n_orders_(max_order - min_order + 1), + min_count_(min_count), + ngram_buckets_(ngram_buckets), + bucket_mask_(static_cast(ngram_buckets - 1)), + alpha_mode_(0), + fixed_alpha_(0.40), + ent_base_(0.05), + ent_range_(0.55), + ent_scale_(2.0), + ent_thresh_(4.0), + order_ent_center_(3.0), + order_ent_slope_(0.25), + mixing_fn_(0) { + ctx_tables_.resize(n_orders_); + full_tables_.resize(n_orders_); + for (int i = 0; i < n_orders_; i++) { + ctx_tables_[i].assign(ngram_buckets, 0); + full_tables_[i].assign(ngram_buckets, 0); + } + } + + void set_tokens( + nb::ndarray, nb::c_contig, nb::device::cpu> + tokens) { + tokens_ = tokens.data(); + n_tokens_ = static_cast(tokens.shape(0)); + } + + void configure_alpha(int mode, double fixed_alpha, double ent_base, + double ent_range, double ent_scale, double ent_thresh, + double order_ent_center, double order_ent_slope) { + alpha_mode_ = mode; + fixed_alpha_ = fixed_alpha; + ent_base_ = ent_base; + ent_range_ = ent_range; + ent_scale_ = ent_scale; + ent_thresh_ = ent_thresh; + order_ent_center_ = order_ent_center; + order_ent_slope_ = order_ent_slope; + } + + void set_mixing_fn(int fn) { mixing_fn_ = fn; } + + void reset() { + for (int i = 0; i < n_orders_; i++) { + std::fill(ctx_tables_[i].begin(), ctx_tables_[i].end(), 0); + std::fill(full_tables_[i].begin(), full_tables_[i].end(), 0); + } + } + + // Process a single stride segment + nb::ndarray> process_stride( + nb::ndarray, nb::c_contig, nb::device::cpu> + positions, + nb::ndarray, nb::c_contig, nb::device::cpu> + model_nll, + nb::ndarray, nb::c_contig, nb::device::cpu> + entropy) { + const int n_pos = static_cast(positions.shape(0)); + auto* out = new double[n_pos]; + + process_stride_impl( + positions.data(), n_pos, model_nll.data(), + (entropy.shape(0) > 0) ? entropy.data() : nullptr, out); + + nb::capsule owner(out, [](void* p) noexcept { + delete[] static_cast(p); + }); + size_t shape[1] = {static_cast(n_pos)}; + return nb::ndarray>(out, 1, shape, + owner); + } + + // Process multiple stride segments in one call (amortizes FFI overhead) + nb::ndarray> process_batch( + nb::ndarray, nb::c_contig, nb::device::cpu> + all_positions, + nb::ndarray, nb::c_contig, nb::device::cpu> + segment_lengths, + nb::ndarray, nb::c_contig, nb::device::cpu> + all_model_nll, + nb::ndarray, nb::c_contig, nb::device::cpu> + all_entropy) { + const int n_segs = static_cast(segment_lengths.shape(0)); + const int32_t* seg_lens = segment_lengths.data(); + const int total = static_cast(all_positions.shape(0)); + const double* ent_base_ptr = + (all_entropy.shape(0) > 0) ? all_entropy.data() : nullptr; + + auto* out = new double[total]; + + int offset = 0; + for (int s = 0; s < n_segs; s++) { + int len = seg_lens[s]; + process_stride_impl( + all_positions.data() + offset, len, + all_model_nll.data() + offset, + ent_base_ptr ? ent_base_ptr + offset : nullptr, + out + offset); + offset += len; + } + + nb::capsule owner(out, [](void* p) noexcept { + delete[] static_cast(p); + }); + size_t shape[1] = {static_cast(total)}; + return nb::ndarray>(out, 1, shape, + owner); + } +}; + +NB_MODULE(fast_ngram_ext, m) { + m.doc() = "C++ accelerated n-gram blend for eval_ngram.py"; + + nb::class_(m, "NGramBlender") + .def(nb::init(), nb::arg("min_order"), + nb::arg("max_order"), nb::arg("ngram_buckets"), + nb::arg("min_count")) + .def("set_tokens", &NGramBlender::set_tokens, nb::arg("tokens")) + .def("configure_alpha", &NGramBlender::configure_alpha, + nb::arg("mode"), nb::arg("fixed_alpha"), nb::arg("ent_base"), + nb::arg("ent_range"), nb::arg("ent_scale"), nb::arg("ent_thresh"), + nb::arg("order_ent_center"), nb::arg("order_ent_slope")) + .def("set_mixing_fn", &NGramBlender::set_mixing_fn, nb::arg("fn")) + .def("reset", &NGramBlender::reset) + .def("process_stride", &NGramBlender::process_stride, + nb::arg("positions"), nb::arg("model_nll"), + nb::arg("entropy")) + .def("process_batch", &NGramBlender::process_batch, + nb::arg("all_positions"), nb::arg("segment_lengths"), + nb::arg("all_model_nll"), nb::arg("all_entropy")); +} diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/submission.json b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/submission.json new file mode 100644 index 0000000000..9b1e3807cb --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/submission.json @@ -0,0 +1,50 @@ +{ + "author": "Abay Bektursun", + "github_id": "abaybektursun", + "name": "Triple Loop + Fused Kernels + Parallel Residuals + N-gram Tilt", + "blurb": "Triple depth recurrence (17 virtual layers from 11 physical), Triton TMA + CUTLASS EVT fused MLP kernels, earlier loop activation (0.35), GPT-J parallel residuals (layers 7+), eval-time causal n-gram tilt. PyTorch 2.9.1.", + "date": "2026-04-06", + "track": "10min_16mb", + "val_loss": 2.79772, + "val_bpb": 1.08014, + "val_bpb_std": 0.0004, + "seeds": [1, 42, 1234, 1337, 2025], + "seed_results": { + "1": { + "val_loss": 2.79767, + "val_bpb": 1.08016, + "artifact_bytes": 15978345, + "steps": 4754 + }, + "42": { + "val_loss": 2.79952, + "val_bpb": 1.08077, + "artifact_bytes": 15975585, + "steps": 4758 + }, + "1234": { + "val_loss": 2.79661, + "val_bpb": 1.07971, + "artifact_bytes": 15973639, + "steps": 4748 + }, + "1337": { + "val_loss": 2.79757, + "val_bpb": 1.08015, + "artifact_bytes": 15974187, + "steps": 4756 + }, + "2025": { + "val_loss": 2.79722, + "val_bpb": 1.07992, + "artifact_bytes": 15970317, + "steps": 4755 + } + }, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "bytes_total": 15989931, + "bytes_code": 19811, + "bytes_code_note": "submission.py (143 bytes) + all code LZMA-compressed (19,668 bytes). Includes minified train_gpt.py + CUTLASS EVT source + n-gram C++ source. Unminified train_gpt.py (64KB) included in PR for readability.", + "note": "val_loss is from quantized sliding window eval (nats). val_bpb includes eval-time n-gram tilt (bits per byte). Seed 1234 n-gram BPB was evaluated in terminal." +} diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/submission.py b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/submission.py new file mode 100644 index 0000000000..74ed90c129 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/submission.py @@ -0,0 +1,3 @@ +import lzma, pathlib +_src = lzma.decompress(pathlib.Path(__file__).with_suffix('.py.lzma').read_bytes()) +exec(compile(_src, __file__, 'exec')) diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_gpt.py b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_gpt.py new file mode 100644 index 0000000000..f5a226c2c6 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_gpt.py @@ -0,0 +1,1529 @@ +import collections +import copy +import glob +import io +import lzma +import math +import os +from pathlib import Path +import random +import re +import subprocess +import sys +import time +import uuid + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn + +from flash_attn_interface import flash_attn_func as flash_attn_3_func + +# --- CUTLASS EVT backward fusion --- +import sys as _sys +_sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cutlass_evt_fusion")) +import cutlass_evt_fusion +import torch.library +@torch.library.register_fake("cutlass_evt::gemm_mul") +def _gemm_mul_fake(go, down_w, act_grad): + return go.new_empty(go.size(0), down_w.size(1)) + +# --- Triton TMA fused leaky_relu_sq forward --- +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +@triton.jit +def _fused_leaky_relu_sq_kernel(a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + # Split into two halves for interleaved write + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + # Forward: compute act_grad = 2*leaky_relu(h)*leaky_relu'(h) and post = leaky_relu(h)^2 + c0_ag = tl.where(c0 > 0, 2.0 * c0, 0.5 * c0) + c_desc.store([offs_am, offs_bn], c0_ag) + c0_post = 0.5 * c0_ag * c0 + aux_desc.store([offs_am, offs_bn], c0_post) + c1 = acc1.to(dtype) + c1_ag = tl.where(c1 > 0, 2.0 * c1, 0.5 * c1) + c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c1_ag) + c1_post = 0.5 * c1_ag * c1 + aux_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c1_post) + +def _triton_fused_leaky_relu_sq(a, b): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + act_grad = torch.empty((M, N), device=a.device, dtype=a.dtype) + post = torch.empty((M, N), device=a.device, dtype=a.dtype) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 256, 64 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(act_grad, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(post, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + def grid(META): + return (min(NUM_SMS, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)),) + _fused_leaky_relu_sq_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, NUM_SMS=NUM_SMS, + num_stages=4, num_warps=8) + return act_grad, post + +# ---------------------------------------- +# Hyperparameters +# ---------------------------------------- + +class Hyperparameters(): + # Experiment settings + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + + # Training length + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + + # Validation/Evals + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 8192)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) + + # Parallel residuals (GPT-J style, layers >= this index) + parallel_residual_start = int(os.environ.get('PARALLEL_RESIDUAL_START', 7)) + hessian_clip_lambda = float(os.environ.get('HESSIAN_CLIP_LAMBDA', 0.175)) + + # Layer looping + num_loops = int(os.environ.get('NUM_LOOPS', 2)) + loop_start = int(os.environ.get('LOOP_START', 4)) + loop_end = int(os.environ.get('LOOP_END', 5)) + enable_looping_at = float(os.environ.get('ENABLE_LOOPING_AT', 0.5)) + + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1'))) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.085)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.997)) + + # Quantization & Compression + compressor = os.environ.get('COMPRESSOR', 'brotli') + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0)) + matrix_bits = int(os.environ.get('MATRIX_BITS', 6)) + embed_bits = int(os.environ.get('EMBED_BITS', 8)) + matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85)) + embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0)) + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + + # Data paths + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + + # Experiment files + logfile = f"logs/{run_id}.txt" + model_path = "final_model.pt" + quantized_model_path = "final_model.int6.ptz" + +# ---------------------------------------- +# Global Logging Function +# ---------------------------------------- + +_logger_hparams = None + + +def set_logging_hparams(h: Hyperparameters) -> None: + global _logger_hparams + _logger_hparams = h + + +def log(msg, console: bool = True) -> None: + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + +# ---------------------------------------- +# Data Loading +# ---------------------------------------- + +class ValidationData: + def __init__(self, h: Hyperparameters, device: torch.device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = ( + build_sentencepiece_luts(self.sp, h.vocab_size, device)) + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + # The BPB calculation assumes "▁" is its own token so that leading-space bytes + # are counted correctly. See https://github.com/openai/parameter-golf/issues/897 + assert sp.piece_to_id("\u2581") != sp.unk_id(), \ + "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" int: + key = str(file) + cached = _SHARD_NTOKENS_CACHE.get(key) + if cached is not None: + return cached + header = np.fromfile(file, dtype=" np.memmap: + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype=" None: + max_phase = min(self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1)) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind:start_ind + self.seq_len + 1], dtype=np.int64)) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ---------------------------------------- +# Model Architecture +# ---------------------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange( + 0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, train_seq_len: int): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class _FullFusedMLP(torch.autograd.Function): + @staticmethod + def forward(ctx, x, fc_w, proj_w): + x_flat = x.reshape(-1, x.shape[-1]) + # Triton TMA: fused fc(x) + leaky_relu + square + act_grad, post = _triton_fused_leaky_relu_sq(x_flat, fc_w) + out = F.linear(post, proj_w) + ctx.save_for_backward(x_flat, fc_w, proj_w, act_grad, post) + return out.reshape(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x_flat, fc_w, proj_w, act_grad, post = ctx.saved_tensors + go = grad_output.reshape(-1, grad_output.shape[-1]) + dW_proj = go.T @ post + # CUTLASS EVT: fused (go @ proj_w) * act_grad + dpre = torch.ops.cutlass_evt.gemm_mul(go, proj_w, act_grad) + dW_fc = dpre.T @ x_flat + dx = dpre @ fc_w + return dx.reshape(grad_output.shape), dW_fc, dW_proj + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if x.is_cuda and self.training: + return _FullFusedMLP.apply(x, self.fc.weight.to(x.dtype), self.proj.weight.to(x.dtype)) + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, train_seq_len: int, + layer_idx: int = 0, ln_scale: bool = False, parallel: bool = False): + super().__init__() + self.parallel = parallel + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x_in) * self.ln_scale_factor + attn_out = self.attn(normed) + if self.parallel: + # GPT-J style: attn and MLP both read from same input + mlp_out = self.mlp(self.mlp_norm(x_in) * self.ln_scale_factor) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + self.mlp_scale.to(dtype=x_in.dtype)[None, None, :] * mlp_out + else: + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp( + self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h: Hyperparameters): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale, + parallel=(i >= h.parallel_residual_start)) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + + # Layer looping + self.looping_active: bool = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices: list[int] = all_indices[:num_enc] + self.decoder_indices: list[int] = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min(len(self.encoder_indices), len(self.decoder_indices)) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif (module.weight.ndim == 2 and module.weight.shape[0] >= 64 and + module.weight.shape[1] >= 64): + nn.init.orthogonal_(module.weight, gain=1.0) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips: list[Tensor] = [] + enc_iter = self.encoder_indices if self.looping_active else range(self.num_encoder_layers) + dec_iter = self.decoder_indices if self.looping_active else range(self.num_encoder_layers, self.num_encoder_layers + self.num_decoder_layers) + for i in enc_iter: + x = self.blocks[i](x, x0) + skips.append(x) + for skip_idx, i in enumerate(dec_iter): + if skip_idx < self.num_skip_weights and skips: + scaled_skip = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") + + +def classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +# ---------------------------------------- +# Optimization +# ---------------------------------------- + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0, + row_normalize: bool = False): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay, + row_normalize=row_normalize), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if group.get("row_normalize", False): + row_norms = g.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + g = g / row_norms.to(g.dtype) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates", + ).split(",") + if pattern +) + + +class Optimizers(): + def __init__(self, h: Hyperparameters, base_model: GPT): + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + +# ---------------------------------------- +# Quantization +# ---------------------------------------- + +def restore_fp32_params(model: nn.Module) -> None: + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +def collect_hessians( + model: nn.Module, + train_loader: ShuffledSequenceLoader, + h: Hyperparameters, + device: torch.device, + n_calibration_batches: int = 64, +) -> dict[str, Tensor]: + hessians: dict[str, Tensor] = {} + hooks = [] + + def make_hook(name: str): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + for name, module in model.named_modules(): + if isinstance(module, CastedLinear) and module.weight.numel() > 65536: + cat = classify_param(name + ".weight") + if cat in ("mlp", "attn"): + hooks.append(module.register_forward_hook(make_hook(name + ".weight"))) + + if model.tie_embeddings: + hook_module = model.head_proj if model.head_proj is not None else model.final_norm + def make_output_hook(name: str): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook("tok_emb.weight"))) + + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + + for hook in hooks: + hook.remove() + + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + + return hessians + + +def gptq_quantize_weight( + w: Tensor, + H: Tensor, + clip_sigmas: float = 3.0, + clip_range: int = 63, + block_size: int = 128, +) -> tuple[Tensor, Tensor]: + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + + row_std = W_orig.std(dim=1) + # Hessian-Aware SDClip: modulate clip per-row by importance + _hsdclip_lambda = float(os.environ.get('HESSIAN_CLIP_LAMBDA', '0.175')) + if _hsdclip_lambda > 0: + h_diag = H.diag().clamp_min(1e-10) + row_importance = (W_orig.float().pow(2) * h_diag[None, :]).sum(dim=1) + row_importance = row_importance / row_importance.mean() + clip_mod = 1.0 + _hsdclip_lambda * (row_importance - 1.0) + s = (clip_sigmas * row_std * clip_mod / clip_range).clamp_min(1e-10).to(torch.float16) + else: + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + + return Q[:, invperm], s + + +def gptq_mixed_quantize( + state_dict: dict[str, Tensor], + hessians: dict[str, Tensor], + h: Hyperparameters, +) -> tuple[dict[str, Tensor], dict[str, object]]: + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + cs = h.embed_clip_sigmas if "tok_emb" in name else h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + q, s = gptq_quantize_weight( + t, hessians[name], clip_sigmas=cs, clip_range=2**(bits - 1) - 1) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + + categories = collections.defaultdict(set) + for name, cat in meta.items(): + short = re.sub(r'\.\d+$', '', re.sub(r'blocks\.\d+', 'blocks', name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + + return result, meta + + +def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data: bytes) -> bytes: + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data: bytes, compressor: str) -> bytes: + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data: bytes, compressor: str) -> bytes: + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + raw = _byte_unshuffle(raw) + return raw + + +def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> tuple[int, int]: + code_bytes = len(code.encode("utf-8")) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + device = torch.device("cuda", h.local_rank) + log("GPTQ:collecting Hessians from calibration data...") + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + hessians = collect_hessians( + base_model, calib_loader, h, device, + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h: Hyperparameters, device: torch.device) -> GPT: + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), + map_location="cpu", + ) + deq_state = dequantize_mixed(quant_state["w"], quant_state["m"], sd_cpu) + eval_model.load_state_dict(deq_state, strict=True) + + return eval_model + +# ---------------------------------------- +# Evaluation +# ---------------------------------------- + +def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]: + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + model: nn.Module +) -> tuple[float, float]: + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " + f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * h.rank) // h.world_size + seq_end = (total_seqs * (h.rank + 1)) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & + ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32 +) -> tuple[float, float]: + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) + if ws + context_size < total_tokens] + + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * (h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & + ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]: + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1000.0 * (time.perf_counter() - t0) + log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") + return val_loss, val_bpb + + +# ----------------------------- +# Training +# ----------------------------- + +def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData): + # Set up model + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if h.distributed: + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) + else: + model = compiled_model + log(f"model_params:{sum(p.numel() for p in base_model.parameters())}") + + # Set up optimizer and load train data + optimizers = Optimizers(h, base_model) + train_loader = ShuffledSequenceLoader(h, device) + + # Helper functions for training + max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1000.0 + log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + + def training_frac(step: int, elapsed_ms: float) -> float: + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def lr_mul(frac: float) -> float: + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + + optimizers.step() + return train_loss + + # Model warmup + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() + for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log(f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"loop_warmup_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed: + model.require_backward_grad_sync = True + train_loader = ShuffledSequenceLoader(h, device) + + # Training loop + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = h.ema_decay + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " + f"step: {step}/{h.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if h.num_loops > 0 and not base_model.looping_active and frac >= h.enable_looping_at: + base_model.looping_active = True + log(f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + train_loss = step_fn(step, scale) + + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + h.train_log_every > 0 + and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1000.0) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} " + f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Weight averaging + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + return base_model, compiled_model + + +def train_and_eval(h: Hyperparameters, device: torch.device) -> None: + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + torch._dynamo.reset() + timed_eval("pre-quantization post-ema", eval_val, h, device, val_data, compiled_model) + + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + timed_eval("quantized", eval_val, h, device, val_data, compiled_model) + if h.sliding_window_enabled: + timed_eval("quantized_sliding_window", eval_val_sliding, h, device, val_data, eval_model) + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl") + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + text=True, check=False).stdout, + console=False, + ) + log("=" * 100, console=False) + + train_and_eval(h, device) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_gpt.py.lzma b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_gpt.py.lzma new file mode 100644 index 0000000000..8310704f7b Binary files /dev/null and b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_gpt.py.lzma differ diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1.log b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1.log new file mode 100644 index 0000000000..092f8edcf8 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1.log @@ -0,0 +1,143 @@ +W0406 14:06:38.218000 3545784 torch/distributed/run.py:803] +W0406 14:06:38.218000 3545784 torch/distributed/run.py:803] ***************************************** +W0406 14:06:38.218000 3545784 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 14:06:38.218000 3545784 torch/distributed/run.py:803] ***************************************** +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +[rank0]:[W406 14:06:48.951750310 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.0 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/parfused_s1.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: parfused_s1 + scalar_lr: 0.02 + seed: 1 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35944536 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0092 val_bpb: 3.4878 +1/20000 train_loss: 9.0111 train_time: 0.0m tok/s: 8640157 +2/20000 train_loss: 12.3640 train_time: 0.0m tok/s: 8519875 +3/20000 train_loss: 11.1567 train_time: 0.0m tok/s: 8406940 +4/20000 train_loss: 9.4647 train_time: 0.0m tok/s: 8353453 +5/20000 train_loss: 8.3465 train_time: 0.0m tok/s: 8331135 +500/20000 train_loss: 3.3318 train_time: 0.8m tok/s: 8064292 +1000/20000 train_loss: 3.1838 train_time: 1.6m tok/s: 8061823 +1500/20000 train_loss: 3.0975 train_time: 2.5m tok/s: 7993400 +2000/20000 train_loss: 3.0737 train_time: 3.3m tok/s: 8007779 +layer_loop:enabled step:2096 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0779 train_time: 4.4m tok/s: 7453592 +3000/20000 train_loss: 2.9537 train_time: 5.6m tok/s: 7032450 +3500/20000 train_loss: 2.9814 train_time: 6.8m tok/s: 6755298 +4000/20000 train_loss: 2.9342 train_time: 8.0m tok/s: 6559204 +4000/20000 val_loss: 2.8912 val_bpb: 1.1193 +4500/20000 train_loss: 2.7820 train_time: 9.2m tok/s: 6417812 +4754/20000 val_loss: 2.8123 val_bpb: 1.0887 +stopping_early: wallclock_cap train_time: 588120ms step: 4754/20000 +peak memory allocated: 39046 MiB reserved: 39072 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80988971 val_bpb:1.08779587 eval_time:6115ms +Serialized model: 135431033 bytes +Code size: 64137 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.9s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15978345 bytes +Total submission size quantized+brotli: 16042482 bytes +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +quantized val_loss:2.84114243 val_bpb:1.09989477 eval_time:8046ms +quantized_sliding_window val_loss:2.79767095 val_bpb:1.08306561 eval_time:91250ms diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1234.log b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1234.log new file mode 100644 index 0000000000..05ab6f21a7 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1234.log @@ -0,0 +1,143 @@ +W0406 13:49:07.980000 3539219 torch/distributed/run.py:803] +W0406 13:49:07.980000 3539219 torch/distributed/run.py:803] ***************************************** +W0406 13:49:07.980000 3539219 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 13:49:07.980000 3539219 torch/distributed/run.py:803] ***************************************** +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +[rank0]:[W406 13:49:19.960407548 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.0 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/parallel_only_s1234.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: parallel_only_s1234 + scalar_lr: 0.02 + seed: 1234 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35944536 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0072 val_bpb: 3.4870 +1/20000 train_loss: 9.0096 train_time: 0.0m tok/s: 8663979 +2/20000 train_loss: 12.3212 train_time: 0.0m tok/s: 8511997 +3/20000 train_loss: 11.1039 train_time: 0.0m tok/s: 8410857 +4/20000 train_loss: 9.4456 train_time: 0.0m tok/s: 8366062 +5/20000 train_loss: 8.3493 train_time: 0.0m tok/s: 8329807 +500/20000 train_loss: 3.3331 train_time: 0.8m tok/s: 8067478 +1000/20000 train_loss: 3.1876 train_time: 1.6m tok/s: 8067680 +1500/20000 train_loss: 3.0948 train_time: 2.5m tok/s: 7988322 +2000/20000 train_loss: 3.0707 train_time: 3.3m tok/s: 7994987 +layer_loop:enabled step:2092 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0749 train_time: 4.4m tok/s: 7432611 +3000/20000 train_loss: 2.9526 train_time: 5.6m tok/s: 7016484 +3500/20000 train_loss: 2.9828 train_time: 6.8m tok/s: 6740255 +4000/20000 train_loss: 2.9319 train_time: 8.0m tok/s: 6546729 +4000/20000 val_loss: 2.8898 val_bpb: 1.1187 +4500/20000 train_loss: 2.7801 train_time: 9.2m tok/s: 6407050 +4748/20000 val_loss: 2.8119 val_bpb: 1.0886 +stopping_early: wallclock_cap train_time: 588134ms step: 4748/20000 +peak memory allocated: 39046 MiB reserved: 39072 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80936815 val_bpb:1.08759396 eval_time:5745ms +Serialized model: 135431033 bytes +Code size: 64137 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.8s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15973639 bytes +Total submission size quantized+brotli: 16037776 bytes +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +quantized val_loss:2.83984087 val_bpb:1.09939090 eval_time:8275ms +quantized_sliding_window val_loss:2.79661076 val_bpb:1.08265518 eval_time:90966ms diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1337.log b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1337.log new file mode 100644 index 0000000000..f9fc49b26c --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed1337.log @@ -0,0 +1,143 @@ +W0406 14:40:24.916000 3558679 torch/distributed/run.py:803] +W0406 14:40:24.916000 3558679 torch/distributed/run.py:803] ***************************************** +W0406 14:40:24.916000 3558679 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 14:40:24.916000 3558679 torch/distributed/run.py:803] ***************************************** +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +[rank0]:[W406 14:40:35.696068494 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.0 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/parfused_s1337.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: parfused_s1337 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35944536 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0047 val_bpb: 3.4860 +1/20000 train_loss: 9.0067 train_time: 0.0m tok/s: 8667970 +2/20000 train_loss: 12.2894 train_time: 0.0m tok/s: 8514181 +3/20000 train_loss: 11.0736 train_time: 0.0m tok/s: 8421277 +4/20000 train_loss: 9.3856 train_time: 0.0m tok/s: 8371013 +5/20000 train_loss: 8.3016 train_time: 0.0m tok/s: 8340507 +500/20000 train_loss: 3.3417 train_time: 0.8m tok/s: 8069784 +1000/20000 train_loss: 3.1894 train_time: 1.6m tok/s: 8068375 +1500/20000 train_loss: 3.0974 train_time: 2.5m tok/s: 8003425 +2000/20000 train_loss: 3.0734 train_time: 3.3m tok/s: 8017414 +layer_loop:enabled step:2099 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0809 train_time: 4.4m tok/s: 7456856 +3000/20000 train_loss: 2.9570 train_time: 5.6m tok/s: 7026438 +3500/20000 train_loss: 2.9810 train_time: 6.8m tok/s: 6747966 +4000/20000 train_loss: 2.9315 train_time: 8.0m tok/s: 6555361 +4000/20000 val_loss: 2.8925 val_bpb: 1.1198 +4500/20000 train_loss: 2.7863 train_time: 9.2m tok/s: 6418337 +4756/20000 val_loss: 2.8137 val_bpb: 1.0893 +stopping_early: wallclock_cap train_time: 588116ms step: 4756/20000 +peak memory allocated: 39046 MiB reserved: 39072 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.81124738 val_bpb:1.08832147 eval_time:6313ms +Serialized model: 135431033 bytes +Code size: 64137 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.8s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15974187 bytes +Total submission size quantized+brotli: 16038324 bytes +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +quantized val_loss:2.84074380 val_bpb:1.09974045 eval_time:7924ms +quantized_sliding_window val_loss:2.79756522 val_bpb:1.08302468 eval_time:91238ms diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed2025.log b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed2025.log new file mode 100644 index 0000000000..3e893952e1 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed2025.log @@ -0,0 +1,143 @@ +W0406 14:57:12.507000 3565043 torch/distributed/run.py:803] +W0406 14:57:12.507000 3565043 torch/distributed/run.py:803] ***************************************** +W0406 14:57:12.507000 3565043 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 14:57:12.507000 3565043 torch/distributed/run.py:803] ***************************************** +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +[rank0]:[W406 14:57:23.536840754 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.0 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/parfused_s2025.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: parfused_s2025 + scalar_lr: 0.02 + seed: 2025 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35944536 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0067 val_bpb: 3.4868 +1/20000 train_loss: 9.0086 train_time: 0.0m tok/s: 8510528 +2/20000 train_loss: 12.3461 train_time: 0.0m tok/s: 8391063 +3/20000 train_loss: 11.1462 train_time: 0.0m tok/s: 8343646 +4/20000 train_loss: 9.4409 train_time: 0.0m tok/s: 8319609 +5/20000 train_loss: 8.3510 train_time: 0.0m tok/s: 8305156 +500/20000 train_loss: 3.3357 train_time: 0.8m tok/s: 8062540 +1000/20000 train_loss: 3.1876 train_time: 1.6m tok/s: 8058208 +1500/20000 train_loss: 3.0948 train_time: 2.5m tok/s: 7990064 +2000/20000 train_loss: 3.0748 train_time: 3.3m tok/s: 7995457 +layer_loop:enabled step:2093 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0769 train_time: 4.4m tok/s: 7433657 +3000/20000 train_loss: 2.9577 train_time: 5.6m tok/s: 7015686 +3500/20000 train_loss: 2.9850 train_time: 6.8m tok/s: 6745541 +4000/20000 train_loss: 2.9320 train_time: 8.0m tok/s: 6557262 +4000/20000 val_loss: 2.8918 val_bpb: 1.1195 +4500/20000 train_loss: 2.7805 train_time: 9.2m tok/s: 6417793 +4755/20000 val_loss: 2.8130 val_bpb: 1.0890 +stopping_early: wallclock_cap train_time: 588072ms step: 4755/20000 +peak memory allocated: 39046 MiB reserved: 39072 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.81060863 val_bpb:1.08807419 eval_time:5601ms +Serialized model: 135431033 bytes +Code size: 64137 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.9s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15970317 bytes +Total submission size quantized+brotli: 16034454 bytes +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +quantized val_loss:2.84043580 val_bpb:1.09962122 eval_time:8047ms +quantized_sliding_window val_loss:2.79722488 val_bpb:1.08289292 eval_time:91467ms diff --git a/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed42.log b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed42.log new file mode 100644 index 0000000000..0ee4ba95f5 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/train_seed42.log @@ -0,0 +1,143 @@ +W0406 14:23:30.258000 3552381 torch/distributed/run.py:803] +W0406 14:23:30.258000 3552381 torch/distributed/run.py:803] ***************************************** +W0406 14:23:30.258000 3552381 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 14:23:30.258000 3552381 torch/distributed/run.py:803] ***************************************** +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +[rank0]:[W406 14:23:41.340347774 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group() +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.0 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/parfused_s42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: parfused_s42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35944536 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0090 val_bpb: 3.4877 +1/20000 train_loss: 9.0111 train_time: 0.0m tok/s: 8663962 +2/20000 train_loss: 12.3688 train_time: 0.0m tok/s: 8503829 +3/20000 train_loss: 11.1518 train_time: 0.0m tok/s: 8400219 +4/20000 train_loss: 9.4477 train_time: 0.0m tok/s: 8345460 +5/20000 train_loss: 8.3657 train_time: 0.0m tok/s: 8306683 +500/20000 train_loss: 3.3330 train_time: 0.8m tok/s: 8047278 +1000/20000 train_loss: 3.1813 train_time: 1.6m tok/s: 8049149 +1500/20000 train_loss: 3.0931 train_time: 2.5m tok/s: 7989930 +2000/20000 train_loss: 3.0705 train_time: 3.3m tok/s: 8004328 +layer_loop:enabled step:2095 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 4, 5] decoder:[4, 5, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0780 train_time: 4.4m tok/s: 7444513 +3000/20000 train_loss: 2.9552 train_time: 5.6m tok/s: 7025710 +3500/20000 train_loss: 2.9813 train_time: 6.8m tok/s: 6752859 +4000/20000 train_loss: 2.9305 train_time: 8.0m tok/s: 6561444 +4000/20000 val_loss: 2.8908 val_bpb: 1.1191 +4500/20000 train_loss: 2.7828 train_time: 9.2m tok/s: 6423199 +4758/20000 val_loss: 2.8114 val_bpb: 1.0884 +stopping_early: wallclock_cap train_time: 588122ms step: 4758/20000 +peak memory allocated: 39046 MiB reserved: 39074 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80895681 val_bpb:1.08743472 eval_time:5604ms +Serialized model: 135431033 bytes +Code size: 64137 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.8s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15975585 bytes +Total submission size quantized+brotli: 16039722 bytes +/home/ubuntu/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4876: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning. + warnings.warn( # warn only once +quantized val_loss:2.84322057 val_bpb:1.10069929 eval_time:8098ms +quantized_sliding_window val_loss:2.79952155 val_bpb:1.08378203 eval_time:91155ms