Record: Triple Loop + Fused Kernels + Parallel Residuals + N-gram Tilt; val_bpb 1.08309 (5-seed mean)#1420
Conversation
…m tilt, SP8192 primary path
- PR openai#771 confirmed CLOSED/REJECTED (train-then-score AdamW TTT)
- PR openai#727 confirmed CLOSED (illegal n-gram hash cache)
- Merged SOTA unchanged at 1.1147
- New primary target: PR openai#1420 (abaybektursun, 1.08014): SP8192 + Triple Loop (3×, 17 virtual layers) + N-gram Tilt (legal, properly normalized, -0.0029 bpb) + Fused Kernels (+127 steps)
- PR openai#1413 (1.08279): confirms legal score-first TTT adds -0.003 bpb
- ETLB (-0.0019 bpb) noted as unruled — await @valerio-oai
- Strategy updated to v10.0: SP8192 + Triple Loop replaces SP4096 + 2×

https://claude.ai/code/session_01TbdBLJPXpbK5wGHpLAQ9x4
Post-Quantization Compression: Eight Negative Results

@clarkkev established in PR #1394 that compressed model size is governed by Shannon entropy, not hardware bitwidth. This note documents eight attempts to improve the compression pipeline beyond SDClip + GPTQ + Brotli. I raided the toolkits of crystallography (E8 lattice sphere packing), particle physics (…).

E8 Lattice Vector Quantization

The E8 lattice achieves optimal sphere packing in 8 dimensions (Viazovska, 2016), with normalized second moment 14% below the cubic lattice. I implemented D8 nearest-point rounding (the integer sublattice: all coordinates with even sum) and measured MSE on Gaussian-distributed weights. D8 increased MSE by 8.37%. The constraint removes half the codewords from the integer grid without adding new ones. The VQ advantage requires dense index-based encoding, not per-coordinate int8 storage. Abandoned.

Entropy Equalization

Interpretability analysis revealed 80x variation in per-matrix quantization sensitivity. I derived the optimal bit allocation via Lagrange multipliers on the rate-distortion model, yielding … Controlled A/B (same Hessians, 5 seeds): -0.004 BPB. End-to-end training: +0.002 BPB. The A/B test isolated the clip-allocation effect by holding GPTQ randomness constant. In practice, GPTQ stochasticity (~0.002 BPB from calibration sampling and floating-point non-determinism) exceeds the signal. The improvement is real but unmeasurable.

Sign-Flip Gauge

MLP hidden neurons admit a …

Scale Discretization

53K per-row float16 scales contribute ~100KB of mantissa entropy. I snapped them to a log-lattice before GPTQ, expecting the solver to absorb the <0.8% perturbation. The artifact grew by 31KB. The discretization destroyed the smooth mantissa gradient that Brotli was already exploiting: Shannon entropy decreased, but Kolmogorov complexity increased.

ZigZag Encoding

Two's Complement maps …

Matrix Transposition

Column-major storage compressed 13KB smaller than row-major on the same quantized tensors because input-feature correlations dominate output-neuron correlations. Combined with stratigraphic dict ordering (grouping same-type matrices for inter-layer LZ77 matches): -16KB offline. End-to-end: +37KB. GPTQ output varies ~30-40KB across runs, overwhelming the signal.

Permutation Sort

MLP hidden-dimension permutation symmetry (…)

The Noise Floor

Every experiment followed the same pattern: positive in controlled settings, neutral or negative end-to-end. The root cause is a GPTQ noise floor of ~0.002 BPB and ~30-40KB in artifact size, arising from Hessian estimation variance and floating-point non-determinism. Any compression-side optimization below this floor is unmeasurable in practice. Brotli quality=11 is empirically at the byte-level compression frontier for this data. Six distinct byte-manipulation strategies (ZigZag, transposition, bit masking, scale discretization, dict reordering, permutation sort) all failed to improve on it.

Experimental Attempts with Negative Results

1. Isospectral Conjugation (Failed: OOM Error)
2. Skip-Gate Variance Normalization (Failed: Redundant)
I have also been playing with E8 lattice VQ over the past few days - despite not managing to break the frontier, it is the best of the VQ methods I've tried, and it certainly beat various learned/shared codebook strategies.

@Eppie Curious to know what you think about the ngram tilt

@mtybadger Cool! We need some more fun ideas that will help beat this SOTA, let me know if you get any more ideas/results

@abaybektursun ngram tilt looks cool! From what I can see, it appears to be fully online; basically, it's a different approach to mixing the prediction from the trained model and the various order-N contexts, with some fixed confidence / count thresholds / priorities. Perhaps "mixing" is the wrong word, since it is focused on narrowing the model's probability distribution by assigning extra probability to the token predicted by the ngram model (as far as I can understand it). Also very cool to see the fused CUDA kernels. Great work!
…am Tilt — val_bpb 1.07800 (3-seed mean)

3-lever stack on top of PR openai#1394 sp8192 baseline:
- Parallel Residuals on layers 7-10 (PR openai#1412 by @Robby955)
- 3-layer depth recurrence (LOOP_START=3 LOOP_END=5, extends PR openai#1394's 2-layer recurrence)
- Eval-time causal n-gram tilt (PR openai#1420 by @abaybektursun, lineage PR openai#1145 by @AnirudhRahul)

Plus our existing PR openai#1413 stack: QK_GAIN_INIT=5, score-first legal TTT (LR=0.005, epochs=3).

Results (3-seed mean, 8xH100 SXM):
- val_bpb 1.07800 (std 0.00053)
- val_loss 2.78457 nats per token
- Beats PR openai#1394 (1.08563) by 0.01971 nats per token
- Beats PR openai#1420 (1.08014) by 0.00553 nats per token
- Beats own PR openai#1413 (1.08279) by 0.01237 nats per token

All four issue openai#1017 conditions verified for the n-gram tilt path: prefix-only hash construction, full-vocab renormalized one-token tilt, score-before-update ordering inside the C++ kernel, single left-to-right pass.

C++ n-gram kernel ported from PR openai#1420 with the nanobind dependency removed (extern "C" shim + ctypes loader, single g++ -shared invocation at runtime).

5-seed re-verification via the shipped mini wrapper is in progress; this PR will be updated with the final 5-seed mean once s1337 and s2025 land.
…val_bpb 1.07983

3-seed mean val_bpb 1.07983 (std 0.00050) on the PR openai#1394 sp8192 stack.

Changes from the PR openai#1394 + PR openai#1413 baseline:
- Muon momentum = 0.97 (vs 0.99 default), warmup 0.92→0.97 unchanged
- Causal token n-gram tilt (base_beta=2.0, agree_bonus=0.1) on top of legal score-first TTT; within-word and word-start experts explicitly disabled (within_beta=0, word_beta=0) because they cannot be made fully causal
- 3-seed verification (seeds 0/42/1234)

Seeds:
- seed 0 → 1.07928 bpb / 2.78790 nats / 15,993,346 bytes
- seed 42 → 1.07997 bpb / 2.78967 nats / 15,992,995 bytes
- seed 1234 → 1.08025 bpb / 2.79039 nats / 15,994,604 bytes
- mean → 1.07983 bpb / 2.78932 nats / 15,993,648 bytes

Delta vs current merged SOTA PR openai#1493 (1.0810): 0.00117 bpb / 0.00302 nats per token

Credits: @clarkkev (base PR openai#1394 sp8192 stack), @abaybektursun (n-gram tilt kernel PR openai#1420, causal fix applied), prior legal-TTT precedent PR openai#549 / PR openai#461.

Platform: 8xH100 80GB SXM, PyTorch 2.9.1+cu128. Training 588s, eval <437s per seed, both under the 600s budget. Artifact under 16 MB on all 3 seeds.
Triple Loop + Fused Kernels + Parallel Residuals + N-gram Tilt
val_bpb: 1.08309 (5-seed mean, std=0.00044)
Changes
- **One extra loop pass through layers 4-5.** PR #1394 (SP8192 + GPTQ Embeddings + Depth Recurrence + MuonEq-R + SDClip, val_bpb 1.08563, 5-seed mean) passes through layers 4-5 three times total (NUM_LOOPS=2, giving 15 virtual layers from 11 physical). I add a fourth pass (NUM_LOOPS=3), giving 17 virtual layers. The encoder becomes `[0,1,2,3,4,5,4,5]` and the decoder `[4,5,4,5,6,7,8,9,10]` (see the schedule sketch after this list). It costs about 200 training steps, but the extra depth more than compensates. Quadruple looping (19 virtual) was worse because the step count drops too far.
- **Activate looping earlier (0.35 instead of 0.50).** At 0.50, half the training budget runs without the looped layers doing anything. I swept `{0.30, 0.35, 0.40, 0.50}` on seed 1234. 0.35 won, though 0.40 was close. Below 0.35 the model doesn't get enough non-looped warmup and quality degrades.
- **Fused MLP kernels (Triton TMA forward + CUTLASS EVT backward).** This took the most engineering effort and gave the most BPB back. The forward fuses `leaky_relu(fc(x), 0.5).square()` into a single Triton TMA kernel so the 403MB intermediate never hits HBM. The backward fuses `(grad_out @ proj.weight) * act_grad` into a CUTLASS 3.x Epilogue Visitor Tree, running the elementwise multiply while tiles are still in registers. Together: ~10% higher throughput, +127 training steps in the same 600s. I initially tried wrapping the entire MLP in a custom `autograd.Function`, but that killed `torch.compile`'s cross-layer fusions and made everything 2.7x slower. The trick was to fuse surgically, just the forward activation and one backward GEMM, and let the compiler handle the rest. Details in Appendix A.1–A.3.
- **Parallel residuals for layers 7-10.** GPT-J style (Wang & Komatsuzaki, 2021): attention and MLP both read from the same pre-residual input, with their outputs summed in parallel (see the block sketch after this list). I expected this to mostly help quantization (less interference between attention and MLP during GPTQ calibration), and it did tighten the gap slightly. The bigger surprise was +68 training steps from the faster forward pass. I also tried Hessian-Aware SDClip from PR #1412 alongside this, but it made things worse with triple looping. It probably needs its own λ tuning for the deeper architecture.
- **Eval-time n-gram tilt (causality-fixed).** The original submission had a causality violation in the within-word and word-start hint channels: the `is_bnd`/`is_ws` flags were derived from `tokens_[p]` (the target token being predicted), which made the hint-gating decision depend on the target. This was caught by @Gusanidas in review. The fix splits the flags into two sets: prefix-derived flags (`tokens_[p-1]`) for hint gating, and target-derived flags (`tokens_[p]`) for post-scoring state updates. However, the within-word and word-start channels cannot produce useful hints without target-dependent gating — they either fire too broadly or at the wrong positions. After testing all causal alternatives (prev_tok gating, state-based gating, disabling channels), the winning configuration uses `token_hint` only (orders 8-16), which was always fully causal. The remaining `token_hint` channel provides a consistent -0.00014 BPB across all seeds. The improvement is real but small — most of the original -0.0029 delta came from the (now-removed) target-dependent gating in the within/word channels. Full details in Appendix A.4.

N-gram legality (#1017 conditions)

Update (post-review fix): The original submission had a Rule 1 violation in the within-word and word-start hint channels. The `is_bnd`/`is_ws` flags used to gate hint generation were derived from `tokens_[p]` (the target), making the decision of whether to produce a hint depend on the token being predicted. This was caught by @Gusanidas. The fix removes the within-word and word-start channels from hint output entirely — they cannot produce useful hints without target-dependent gating. Only the `token_hint` channel (orders 8–16) remains, which was always fully causal. The n-gram delta dropped from -0.0029 to -0.00014 BPB.

Audited against the four conditions proposed in #1017 for eval-time adaptation:
- **Condition 1, Causal dependence** (`p_t` depends only on artifact + `x_1...x_{t-1}`): `compute_hashes` reads `tokens[pos - k - 1]` for k=0,1,..., all strictly before position `pos`. `token_hint` looks up hash tables containing only entries inserted by prior iterations. The target token `tokens[pos]` is read only for the post-scoring update phase.
- **Condition 2, Full normalized distribution:** The tilted distribution is `p_tilt(t) = p_model(t) · exp(β · 1[t==hint]) / Z` where `Z = 1 + p_model(hint) · (exp(β) - 1)`. Proper probability distribution over the full vocabulary.
- **Condition 3, Score-before-update:** Hint and beta are written to the output arrays before `token_update` inserts `tokens[pos]` into the tables.
- **Condition 4, Single left-to-right pass:** `get_hints_batch` processes positions sequentially. The sliding window scores each token exactly once.

- **Double-buffered async data prefetch.** Background thread + pinned memory + separate CUDA stream. I built this to work around the virtualized disk I/O on cloud H100 instances (see below), but it ended up helping in every setting I tested.
- **PyTorch 2.9.1 instead of 2.11.** See below.
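For concreteness, here is a minimal sketch of the virtual-layer schedule described in the loop-pass item above. The names (`ENCODER_IDS`, `DECODER_IDS`, `LOOP_START_FRAC`, `looped_forward`) are illustrative, not the identifiers used in `train_gpt.py`:

```python
import torch
import torch.nn as nn

# Virtual-layer schedules as described above (11 physical blocks, indices 0-10).
# NUM_LOOPS=2 (PR #1394): blocks 4-5 are traversed three times -> 15 virtual layers.
# NUM_LOOPS=3 (this PR):  blocks 4-5 are traversed four times  -> 17 virtual layers.
ENCODER_IDS = [0, 1, 2, 3, 4, 5, 4, 5]
DECODER_IDS = [4, 5, 4, 5, 6, 7, 8, 9, 10]
LOOP_START_FRAC = 0.35  # looping switches on 35% of the way through training (swept over {0.30..0.50})

def looped_forward(blocks: nn.ModuleList, x: torch.Tensor, step: int, total_steps: int) -> torch.Tensor:
    # Before the activation point every physical block runs once; afterwards
    # blocks 4-5 are re-entered according to the schedule above.
    # U-net skip connections and their gates are omitted from this sketch.
    if step < LOOP_START_FRAC * total_steps:
        order = list(range(len(blocks)))
    else:
        order = ENCODER_IDS + DECODER_IDS
    for i in order:
        x = blocks[i](x)
    return x
```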
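And a sketch of the GPT-J-style parallel-residual block from the parallel-residuals item. Generic attention and MLP modules stand in for the repo's own (the real blocks use causal attention, RoPE, and the squared leaky-ReLU MLP):

```python
import torch
import torch.nn as nn

class ParallelResidualBlock(nn.Module):
    """Attention and MLP read the same normalized pre-residual input and their
    outputs are summed in one residual add (Wang & Komatsuzaki, 2021)."""
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.norm = nn.RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)                                  # single shared read of the residual stream
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        return x + attn_out + self.mlp(h)                 # parallel, not sequential, residual
```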
What the model looks like inside
I ran per-matrix rate-distortion, recurrence error amplification, and skip gate analysis on the trained model. Three things stood out:
Loop layers are 2.2x more sensitive to quantization than non-loop layers. Blocks 4 and 5 get reused across passes, so rounding error in those weights compounds. The single most sensitive matrix in the entire network (block 4's value projection) has 80x the BPB-per-byte cost of the least sensitive. This suggests mixed-precision quantization (more bits for loop layers) is the biggest remaining opportunity.
The third loop pass contributes 63% of what the second does. I measured a contraction ratio of 0.634 across passes: each loop iteration changes the representation by ~63% as much as the previous one. A hypothetical 4th pass would add only 0.63³ ≈ 25% new information (relative to the first pass), which matches the empirical finding that quadruple looping hurts. The 3rd pass at 63% is clearly worth the step cost; the 4th at 25% is not.
All 8 skip connections are load-bearing. Gates are 0.61-0.70 (sigmoid), meaning roughly 35% encoder / 65% decoder blend. The first loop pass's skip connections (skips 2,3) have the highest weight norms (21.9, 19.5 vs 2.8-13.8 for others), so the first encoder pass through layers 4-5 is the most important information source for the decoder.
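A minimal sketch of the kind of sigmoid-gated skip blend these numbers describe (the parameterization is an assumption for illustration, not lifted from `train_gpt.py`):

```python
import torch
import torch.nn as nn

class GatedSkip(nn.Module):
    """Blend the decoder stream with a saved encoder activation via a learned scalar gate."""
    def __init__(self):
        super().__init__()
        self.gate_logit = nn.Parameter(torch.zeros(()))    # sigmoid(0) = 0.5 at init

    def forward(self, decoder_x: torch.Tensor, encoder_skip: torch.Tensor) -> torch.Tensor:
        g = torch.sigmoid(self.gate_logit)                  # trained gates land around 0.61-0.70
        return g * decoder_x + (1 - g) * encoder_skip       # ~65% decoder / ~35% encoder blend
```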
What the progress looks like: three models on the same prompt (temp=0.8)
Prompt (50 tokens): "Insurance Company Declares Living Man Dead George Johannesen is very much alive. Which is why it was so surpr"
Ground truth: ising when the Canadian man received a letter addressed "To the Estate of George Johannesen." Even more surprising is that it came from his insurance company, who should really be on top of such things...
#1019 drifts into incoherence ("Rachel Drobles... techniques of the car industry... Lyon Man is dead"). #1105 stays on topic but loops on "Living Man is the only insurance company." This model picks up the actual narrative thread ("the Canadian man received a letter"), invents plausible biographical details, and maintains coherence throughout. All three are wrong about what happens next, but the errors become progressively more plausible.
Debugging the platform
This was the hardest submission I've worked on. Most of the time went to infrastructure, not the model.
Virtualized disks tank throughput. The cloud H100 instances I rented use virtio block devices. The coprime-stride data loader from #726 does random reads across 143 shards, which is fine on bare metal but brutal on a virtual disk. That's what led me to build the async prefetch. It turned out to help everywhere, not just on virtualized storage.
PyTorch 2.9.1 vs 2.11: a full day lost. I could not reproduce results from other submissions. Training the same architecture with the same seed gave 0.0042 BPB worse results on torch 2.11. (I initially measured a 0.015 gap, which turned out to be a wrong model file on the server. The real gap, once I controlled for that, was 0.0042.) I swapped Triton versions, disabled autocast, forced cuBLAS backends, diffed Inductor-generated kernels. The root cause was two independent issues:
- **Autocast backward changed in PR pytorch#165068** (landed Dec 2025, present in 2.11, absent from 2.9.1). Two lines in `cached_cast()` add an `AutoGradMode enable_grad(true)` guard on weight casts, inserting extra `ToCopyBackward` nodes into the autograd graph. This changes floating-point accumulation order by 1 ULP of bf16 (7.15e-7) in saved activations, which compounds over 5000 momentum steps into +60KB of weight entropy. The model goes from fitting at 16.00MB (no pruning) to 16.06MB (5.4% pruning needed). I verified eval is version-invariant to 0.00003 BPB; the entire gap is from training.
- **Inductor over-fusion in backward codegen:** Inductor 2.11's `mix_order_reduction` fuses `_fused_rms_norm_backward` into adjacent kernels, producing fewer but larger Triton kernels (65 functions / 11,855 lines vs 71 / 11,292 in 2.9.1). The fatter kernels hit register pressure and cost +5.93ms per backward pass (+8.8%). In a 600s budget, that's ~57 lost training steps. I submitted a fix that disables `mix_order_reduction` by default (aligning open-source with fbcode, where it was already off): pytorch/pytorch#179494.

Separately, our fused CUTLASS kernel crashed on torch 2.11 because Triton 3.6.0's `TensorDescriptor.from_tensor()` tries to access `.data_ptr()` on FakeTensors during `torch.compile` tracing. I traced that through Inductor's `FallbackKernel` codegen and submitted a second fix: pytorch/pytorch#179422. Two PyTorch PRs from a golf competition.

In time-budgeted competitions, the platform is the model. A 6ms/step Inductor regression can cost as much BPB as most algorithmic innovations.
How this submission came together
The first few days were mostly wasted. I tried improving the architecture directly: 12 layers, SwiGLU, mixed int5/int8 per layer. Nothing worked. The model was 930KB over the 16MB budget and MLP weights alone were 69% of the compressed artifact. Brotli-11 was already within 1-2% of Shannon entropy. There was nowhere to go.
Worse: a new optimizer schedule I'd been developing (Mixed NS5, a convergent Newton-Schulz coefficient ramp) changed the weight distribution enough that the model no longer fit in the 16MB budget. It was 930KB over, and aggressive pruning to fit destroyed the quality gains.
Then I lost a full day to PyTorch version divergence (described above). Besides the upstream fix, the useful thing that came out of it was a demonstration that compressed model size is a chaotic function of training hyperparameters: 1 ULP of bf16 rounding (7.15e-7) in a saved activation compounds over 5000 momentum steps into 60KB swings in Brotli output. I also proved that L2 weight decay is scale-invariant under max-abs quantization: `Q(γW) = Q(W)`. All the per-bank WD tuning I'd been doing was chasing noise.

Once I stopped trying to control compression through training and focused on what was actually deterministic (GPTQ deadzone for size, n-gram tilt for eval), things moved fast. Clean reproduction of the baseline. Pivot to SP8192 + SDClip. Triple looping. Fused kernels. Parallel residuals. Each gain was small but they stacked: 45 experiments, five seeds, 1.08014 BPB.
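A quick numerical check of that scale-invariance claim, using the max-abs int5-style quantizer described below (powers of two are used for γ so the floating-point arithmetic in the check is exact; in real arithmetic the identity holds for any γ > 0):

```python
import torch

def quantize_maxabs(w: torch.Tensor, levels: int = 31) -> torch.Tensor:
    # Symmetric max-abs quantization: integers in [-levels, levels].
    scale = w.abs().max() / levels
    return torch.round(w / scale)

w = torch.randn(512, 512)
for gamma in (0.25, 0.5, 2.0, 4.0):   # a global shrink/grow, e.g. the net effect of weight decay
    assert torch.equal(quantize_maxabs(gamma * w), quantize_maxabs(w))
# Q(γW) = Q(W): the quantized integers are identical, so weight decay cannot steer
# the compressed artifact through this quantizer.
```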
What didn't work
Innovations that worked on earlier models but not here
- **Mixed NS5 coefficient schedule.** On our SP4608 model this was worth -0.0066 BPB for free: use the standard Muon polynomial `(3.4445, -4.775, 2.0315)` to ramp singular values toward 1, then switch to the convergent polynomial `(1.875, -1.25, 0.375)`, which has `p(1)=1, p'(1)=0`, to lock them in (see the sketch after this list). The split adapts per bank based on aspect ratio as a proxy for condition number. On the SP8192 architecture the coefficient schedule produced weight distributions that were hostile to Brotli compression: the model was 500KB over budget and needed 46% pruning.
- **EC-GPTQ (entropy-constrained rounding).** Inside the GPTQ inner loop, I added an element-wise deadzone: `dz = λ · d / s²`, where d is the Hessian diagonal and s is the scale. Borderline ±1 values get rounded to 0 when the GPTQ error-compensation cost is small. On the SP4096 architecture this achieved 10x better rate-distortion than uniform deadzoning (0.5×10⁻⁵ BPB/KB vs 6.8×10⁻⁵). On the SP8192 + SDClip architecture it was harmful: SDClip's `c = k·σ` already controls entropy per row, and adding EC-GPTQ on top just introduced extra quantization damage for no compression benefit.
- **Per-bank weight decay tuning.** MLP is 69% of the compressed model. I tried giving MLP slightly lower WD (0.07 vs 0.09) to improve quality, offset by higher attention WD. Even ±0.005 from the baseline was catastrophic: lower MLP WD means larger MLP weights, which Brotli can't compress cheaply, so the artifact blows up.
- **L2 weight decay as a compression lever.** I proved mathematically that L2 WD is scale-invariant under max-abs quantization: `Q(γW) = round(W / (max|W|/31)) = Q(W)`. Multiplying all weights by a constant changes nothing about the quantized integers. This was useful to understand (it meant all the WD-based compression tuning I'd been doing was chasing noise), but it also closed a door.
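For reference, a sketch of the Mixed NS5 idea from the first item above, written against the standard Muon Newton-Schulz iteration (the iteration counts and the per-bank split are illustrative, not the tuned values):

```python
import torch

STANDARD   = (3.4445, -4.7750, 2.0315)   # Muon's default quintic: fast ramp, oscillates around 1
CONVERGENT = (1.8750, -1.2500, 0.3750)   # p(1)=1, p'(1)=0: holds singular values at 1

def mixed_ns5(G: torch.Tensor, ramp_steps: int = 3, lock_steps: int = 2) -> torch.Tensor:
    """Ramp singular values toward 1 with the standard polynomial, then switch to
    the convergent polynomial to lock them in."""
    X = G.bfloat16()
    X = X / (X.norm() + 1e-7)                 # bring the spectral norm into range
    if transposed := (X.shape[0] > X.shape[1]):
        X = X.T
    for step in range(ramp_steps + lock_steps):
        a, b, c = STANDARD if step < ramp_steps else CONVERGENT
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X   # quintic in X: a·X + b·(XXᵀ)X + c·(XXᵀ)²X
    return (X.T if transposed else X).to(G.dtype)
```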
Code size

All code ships as part of the artifact: `train_gpt.py`, the CUTLASS EVT source, and the n-gram C++ source. For a competition run, these would be bundled into a single LZMA-compressed blob. `train_gpt.py` is minified with `python-minifier` (annotations, pass statements, and docstrings removed; variable names preserved). `submission.py` (143 bytes) is the entry point: it decompresses `train_gpt.py.lzma` and executes it (sketched below). For a competition run, `torchrun` would invoke `submission.py` instead of `train_gpt.py`. Total code cost: 19,811 bytes. All 5 seeds fit under 16MB with 1.8-9.9KB headroom. The unminified `train_gpt.py` (64KB) is included in the PR for readability.
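A sketch of what an entry point of this shape can look like (illustrative; the actual 143-byte `submission.py` is not reproduced here):

```python
# submission.py (sketch): decompress the minified training script and execute it
# in the current interpreter, so the usual launch command is unchanged.
import lzma
exec(lzma.open("train_gpt.py.lzma").read())
```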
Requirements

pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291

Credits
Full component lineage: every piece traced to its origin PR
This competition is deeply collaborative. Nearly every component traces through multiple contributors. I've tried to credit the earliest PR that introduced each technique, but many were refined across several submissions.
Appendix
A.0 Ablation: fused 5-seed without parallel residuals
5-seed results: fused kernels + triple loop + n-gram, no parallel residuals
5-seed mean: 1.08080 BPB (std=0.00064). Seed 1234 n-gram was run in terminal (1.08007), not logged to file.
Adding parallel residuals (layers 7+) improves seed 1234 from 1.08007 to 1.07971 (-0.00036), primarily from +68 extra training steps due to the faster parallel forward pass. Full parallel-residuals 5-seed results are in the main table above (mean 1.08014).
A.1 Fused MLP Kernels: Design & Implementation
These kernels were first developed for PR #1105 on the SP4608 architecture. This submission ports them to the SP8192 + triple-loop architecture and integrates the CUTLASS EVT backward with `torch.compile`'s tracing.

Forward (Triton TMA): fuses F.linear + LeakyReLU(0.5) + square
- Fuses `F.linear(x, up_w) -> LeakyReLU(0.5) -> square` into a single kernel. The 403MB intermediate never touches HBM.
- Uses Triton's Tensor Memory Access (TMA) descriptors for H100-native global-to-shared memory loads. Block sizes `128x256x64` with 8 warps, 4 pipeline stages. The kernel performs the GEMM accumulation in FP32, then applies activation and squaring inline before writing back to BF16.
- The interleaved write pattern splits the accumulator into two halves via `tl.reshape + tl.permute + tl.split`, writing the activation gradient and the post-activation to separate output buffers in a single pass.

Backward (CUTLASS EVT): fuses (go @ down_w.T) * act_grad
- Fuses `(go @ down_w.T) * act_grad` into a single CUTLASS 3.x kernel via an Epilogue Visitor Tree. The elementwise multiply runs in the GEMM epilogue while tiles are still in registers, eliminating one 403MB write + read per layer.
- I store the activation gradient in the forward pass instead of the pre-activation. This removes all branching from the backward:
The identity `post = 0.5 * act_grad * pre` holds for both signs of the pre-activation, so the backward never needs to know which branch of the leaky ReLU was taken. This reduces the CUTLASS EVT epilogue to a trivial 3-node tree: `Sm90EVT<multiplies, AccFetch, AuxLoad>`.
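A quick numerical check of the identity (a plain-PyTorch sketch, not the kernel code): for `post = leaky_relu(pre, 0.5)**2`, the stored `act_grad = d(post)/d(pre)` satisfies `post = 0.5 * act_grad * pre` on both branches.

```python
import torch
import torch.nn.functional as F

pre = torch.randn(4096, dtype=torch.float64, requires_grad=True)
post = F.leaky_relu(pre, 0.5).square()

# act_grad = d(post)/d(pre): what the fused forward stores instead of the pre-activation
act_grad, = torch.autograd.grad(post.sum(), pre)

# Branch-free reconstruction used in the epilogue: post == 0.5 * act_grad * pre
assert torch.allclose(post, 0.5 * act_grad * pre)
```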
Why surgical fusion, not full-MLP autograd.Function

`torch.compile`'s cross-layer fusions (RMSNorm backward, residual adds, RoPE backward) account for ~21.6% of step time. Wrapping the full MLP backward in an `autograd.Function` makes it opaque to Inductor, so everything runs in eager mode at a 2.7x slower net (I hit this in my #670). So I fuse only the forward activation and one backward GEMM+pointwise, preserving the compiler's scope over everything else.

A.2 Kernel Benchmarks
Per-layer timing and end-to-end
End-to-end (35 steps, seed=42, 2xH100):
On 8xH100: unfused 4553 steps → fused 4680 steps in 588s (+127 steps, +2.8%).
A.3 Step-Time Profile
Where all 313ms goes (2xH100, Nsight Systems)
A.4 N-Gram Tilt
The n-gram system was originally developed in PR #1105 for SP4608 models. This submission ports it to SP8192. Source code: `ngram/fused_expert_blend.cpp` (C++ open-addressing hash, nanobind FFI) and `ngram/eval_ngram.py` (tilt math + sliding window). Eval time on 8xH100: ~90s.

Post-review causality fix
The original submission had three hint channels: `token_hint` (orders 8–16), `within_hint` (within-word BPE completion), and `word_hint` (word-start prediction). @Gusanidas identified that `within_hint` and `word_hint` used `is_bnd`/`is_ws` flags derived from `tokens_[p]` (the target token) to gate whether a hint was produced — a Rule 1 violation.

What was invalid: The gating decision "should I produce a hint at this position?" depended on whether the target token was a word boundary or had a leading space. This meant the probability distribution P(x_t | x_1...x_{t-1}) changed depending on the value of x_t itself.
What was tried to salvage the within/word channels:
- `is_bnd`/`is_ws` from `tokens_[p-1]` (prefix): semantically inverted, delta = +0.00033 (harmful)
- `within_len_state` only: fires too broadly, delta = +0.00120 (harmful)

Conclusion: The within/word channels' -0.0025 BPB contribution came entirely from target-dependent gating. Without it, they add noise. Only `token_hint` (orders 8–16) produces a legitimate improvement. The fix removes within/word from hint output while keeping their state updates (dead code, no effect).

Parameter sweep (token_hint only, 4M token subset, 8 GPUs in parallel):
Full-val delta with best params (beta=1.5): consistent -0.00014 BPB across all 5 seeds. The improvement is real but small.
Causality proof (token_hint channel)
The surviving `token_hint` channel is a textbook online n-gram with strict lookup-then-update discipline: `p_t` depends only on the artifact + `x_1...x_{t-1}`, and the `x_t`-dependent update (inserting `tokens[pos]` into the hash tables) happens only after the position has been scored.
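A Python sketch of the scoring loop this describes (the C++ hash tables are replaced by a plain dict keyed on (order, context tuple); `beta` and the order range mirror the text, the hint-priority rule and `model_logprob` interface are assumptions for illustration):

```python
import math

def ngram_tilted_logprobs(tokens, model_logprob, orders=range(8, 17), beta=1.5):
    """Single left-to-right pass: look up a hint from prefix-only context, tilt the
    model's distribution toward it with full-vocab renormalization, and only then
    insert the just-scored target into the tables (score-before-update)."""
    tables = {}                       # (order, context tuple) -> predicted next token
    total = 0.0
    for pos in range(1, len(tokens)):
        # 1) hint from x_1..x_{t-1} only; longest matching order wins (illustrative priority)
        hint = None
        for k in sorted(orders, reverse=True):
            if pos >= k and (k, tuple(tokens[pos - k:pos])) in tables:
                hint = tables[(k, tuple(tokens[pos - k:pos]))]
                break
        # 2) score the target under the tilted, renormalized distribution:
        #    p_tilt(t) = p(t) * exp(beta * 1[t == hint]) / Z, Z = 1 + p(hint) * (exp(beta) - 1)
        p = {t: math.exp(lp) for t, lp in model_logprob(tokens[:pos]).items()}  # full-vocab dict
        z = 1.0 + (p.get(hint, 0.0) * (math.exp(beta) - 1.0) if hint is not None else 0.0)
        boost = beta if tokens[pos] == hint else 0.0
        total += math.log(p[tokens[pos]]) + boost - math.log(z)
        # 3) the x_t-dependent update happens only after scoring
        for k in orders:
            if pos >= k:
                tables[(k, tuple(tokens[pos - k:pos]))] = tokens[pos]
    return total
```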
A.5 Data Prefetch

Double-buffered async prefetch
Background thread prepares next batch in pinned memory while GPU trains. Separate CUDA stream for H2D overlap.
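A sketch of the buffering pattern (the real loader, stream handling, and batch shapes live in `train_gpt.py`; this only shows the thread + pinned-memory + side-stream structure):

```python
import threading
import torch

class PrefetchLoader:
    """Wrap a batch iterator: a background thread stages the next batch into pinned
    host memory and issues the H2D copy on a side stream, overlapping training."""
    def __init__(self, batches, device="cuda"):
        self.batches, self.device = iter(batches), device
        self.stream = torch.cuda.Stream()
        self.next_gpu = None
        self._preload()

    def _stage(self):
        cpu = next(self.batches, None)
        if cpu is None:
            self.next_gpu = None
            return
        pinned = cpu.pin_memory()                            # page-locked staging buffer
        with torch.cuda.stream(self.stream):
            self.next_gpu = pinned.to(self.device, non_blocking=True)

    def _preload(self):
        self.thread = threading.Thread(target=self._stage)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        self.thread.join()
        torch.cuda.current_stream().wait_stream(self.stream)  # ordering: copy before compute
        batch = self.next_gpu
        if batch is None:
            raise StopIteration
        batch.record_stream(torch.cuda.current_stream())       # allocator safety across streams
        self._preload()
        return batch
```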
On the PR #1334 architecture: +39 steps, +0.7% throughput. The extra steps landed in a worse compression region (+40KB), so the net effect was actually harmful for that architecture. On PR #1394's `ShuffledSequenceLoader` with memmap, the data pipeline is already efficient enough that prefetch isn't the bottleneck.

A.6 ETLB (Eval-Time Logit Bias)
Algorithm and results
From PR #1399. Learns a vocab-sized bias vector via SGD on already-scored context tokens, carried across sliding windows; scoring uses `logits + bias`.

Result (seed 1234, double-loop config on torch 2.11): n-gram only 1.08152 → ETLB + n-gram 1.08132 (-0.00020). Not re-tested on the final triple-loop fused config.
Rejected: takes 615s, doesn't fit in 600s eval budget.
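For reference, a sketch of the kind of ETLB update described above (the learning rate, the per-token rather than per-window update, and the absence of regularization are assumptions; PR #1399 has the real version):

```python
import torch
import torch.nn.functional as F

def etlb_score_window(logits: torch.Tensor, targets: torch.Tensor,
                      bias: torch.Tensor, lr: float = 0.1) -> tuple[float, torch.Tensor]:
    """Score one sliding window with the current bias, then take SGD steps on the
    bias using only tokens that have already been scored. The updated bias is
    carried into later positions and windows, never applied retroactively."""
    total_nll = 0.0
    for t in range(targets.shape[0]):
        biased = logits[t] + bias                        # vocab-sized learned offset
        total_nll += F.cross_entropy(biased.unsqueeze(0), targets[t:t + 1]).item()
        # SGD on the already-scored token: grad of NLL w.r.t. bias = softmax(biased) - one_hot(target)
        grad = torch.softmax(biased, dim=-1)
        grad[targets[t]] -= 1.0
        bias = bias - lr * grad
    return total_nll, bias
```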
A.7 Setup & Reproduction
Full build instructions