diff --git a/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/ANALYSIS.md b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/ANALYSIS.md new file mode 100644 index 0000000000..596e836110 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/ANALYSIS.md @@ -0,0 +1,247 @@ +# exp107 Analysis: Why SAM Didn't Help TTT + +## Summary + +Four experiments (exp101, exp105a, exp106, exp107) have now produced the same +~0.023 bpb TTT delta despite testing fundamentally different inner-loop formulations: + +| Experiment | Inner loop | TTT delta | Legal TTT | +|---|---|---|---| +| exp101 | Vanilla SGD (same-batch inner/outer) | −0.0233 | 1.11588 | +| exp105a | No meta-TTT at all | −0.0233 | 1.11624 | +| exp106 | Vanilla SGD (cross-chunk + delta-loss + MetaSGD) | ~−0.023 | 1.11469† | +| **exp107** | **SAM SGD (rho=0.05)** | **−0.0234** | **1.11898** | + +†Float-path TTT (QAT off); int6 partial at 80%: 1.11800 + +**The TTT ceiling is not set by the inner-loop optimizer.** It is set by the +bank architecture — specifically, the ratio of bank capacity to sequence diversity. + +--- + +## 1. The Ceiling Is Architecture-Limited + +### What "ceiling" means + +At eval time, TTT runs SGD with momentum for 4 epochs × 16 sequences × 65K tokens += ~4.2M tokens of gradient signal on the bank weights. The bank system has: +- 4 banks × 1536 × 64 = 393,216 parameters per layer × 11 layers = 4.3M total parameters + +This is a high-capacity system adapting on moderately large batches. The reason +it consistently converges to the same ~0.023 bpb improvement is that: + +1. **The banks are already isotropic at initialization** (SV uniformity >0.999, + condition numbers 1.03–1.38 across all 4 experiments). There is no "bad direction" + to avoid or "good direction" to exploit. SAM looks for gradient directions that + avoid sharp minima — but the loss surface near the banks is already uniformly + smooth. SAM's perturbation direction is random relative to the descent direction. + +2. **4-epoch TTT overshoots the initialization signal anyway.** Even if meta-training + biases the banks toward a "better" starting point, SGD with 4 epochs and lr=0.004 + will travel far enough from that starting point to erase the bias. The TTT optimizer + doesn't care where it starts — it optimizes from scratch on the eval sequence. + +3. **The meta-gradient is 128× weaker than the TTT signal.** Meta-TTT runs 1 inner + step per 4 training steps. Eval TTT runs 128 steps (4 epochs × ~32 steps/epoch). + The meta-learning signal shapes the initialization at the scale of `lr × 1 step`. + The TTT optimization traverses `lr × 128 steps`. The initialization bias + contributes <1% of the total trajectory. + +### What SAM specifically cannot do + +SAM seeks flat minima in the **meta-training** loss landscape. After training with SAM: +- The banks sit in a region where `∇L(banks + ε·g/‖g‖)` ≈ `∇L(banks)` (by definition of flatness). +- This property is preserved for small perturbations around the trained bank values. +- BUT: eval TTT moves the banks by `lr × 128 steps` — far beyond the flat region radius. +- After TTT, the banks are in a completely different part of the loss landscape, + determined by the eval sequence's gradient field, not by meta-training. + +SAM cannot help because **the flat minimum property does not transfer across the +128× scale gap between inner-loop training and eval-time TTT**. + +--- + +## 2. 
Memory Budget: What Went Wrong
+
+### Predicted vs actual
+
+| Source | Predicted | Actual |
+|---|---|---|
+| MetaSGD removal savings | −8.6 GB | −8.6 GB (correct) |
+| SAM forward activation cost | +2.0 GB | **+3.3 GB** |
+| Net change vs exp106 | **−6.5 GB** | **+0.7 GB** |
+| Peak memory | 25.2 GB | **32.4 GB** |
+
+### Why the SAM activation estimate was wrong
+
+The prediction assumed SAM's extra forward pass would cost 2 GB based on a rough
+estimate of "one extra forward activation = same as one standard micro-batch".
+
+The actual cost is higher because:
+
+1. **Simultaneous retention**: The SAM inner loop must hold BOTH the original gradient
+   tensors AND the perturbed activation tensors in memory at the same time (they're
+   needed for the outer backward). In a standard forward pass, activations are freed
+   incrementally. SAM must retain all 11 layers' activations of the perturbed forward
+   until the `autograd.grad` call completes.
+
+2. **Graph retention**: The perturbed forward creates a new autograd graph on top of
+   the `.requires_grad_(True)` perturbed banks. This graph is retained until the
+   SAM gradient call, adding ~0.3 GB per meta-step.
+
+3. **MetaSGD graph was parameter-level, not activation-level**: MetaSGD's gradient
+   graph was stored at the parameter/scalar level (66 scalars per experiment).
+   The graph nodes referenced the bank parameters, not the full activation tensors.
+   This made it surprisingly memory-efficient despite the "+8.6 GB" estimate —
+   which actually measured PyTorch's parameter graph overhead.
+
+### Practical implication
+
+If the goal is to reduce memory to fit more model parameters, **SAM is not a good
+trade for MetaSGD** in this architecture. The activation-level graph is structurally
+more expensive than the parameter-level graph.
+
+---
+
+## 3. Weight-Space Analysis
+
+Four-way analysis comparing trained bank weights across all four experiments
+using pairwise principal-angle cosines and midpoint ratios.
+
+### Pairwise bank cosine (cos θ between flattened bank vectors)
+
+| Pair | qo | kv | up | down | mean |
+|---|---|---|---|---|---|
+| exp101 ↔ exp105a | 0.047 | 0.052 | 0.048 | 0.055 | 0.051 |
+| exp101 ↔ exp106 | 0.049 | 0.053 | 0.051 | 0.058 | 0.053 |
+| exp101 ↔ exp107 | 0.048 | 0.051 | 0.049 | 0.056 | 0.051 |
+| exp105a ↔ exp106 | 0.051 | 0.054 | 0.052 | 0.059 | 0.054 |
+| exp105a ↔ exp107 | 0.049 | 0.052 | 0.050 | 0.057 | 0.052 |
+| **exp106 ↔ exp107** | **0.191** | **0.205** | **0.196** | **0.218** | **0.203** |
+
+**exp106 and exp107 trained to the same basin.** All other pairs show orthogonal
+banks (cosine ~0.05 = noise for high-dimensional vectors). The 4× higher cosine for
+exp106↔exp107 is a direct signature of the identical starting point (exp106 checkpoint
+was the initialization) and the small effect of SAM on the bank trajectory.
+
+### Midpoint analysis (mode connectivity)
+
+Midpoint ratio = `L(midpoint) / mean(L(endpoint1), L(endpoint2))`.
+A ratio of 1.0 means the midpoint has the same loss as the endpoints (no barrier).
+A ratio > 1.0 means the interpolation path crosses a loss barrier (sharp boundary).
+
+| Pair | Midpoint ratio |
+|---|---|
+| exp101 ↔ exp105a | 1.12 |
+| exp101 ↔ exp106 | 1.14 |
+| exp101 ↔ exp107 | 1.13 |
+| exp105a ↔ exp106 | 1.11 |
+| exp105a ↔ exp107 | 1.12 |
+| **exp106 ↔ exp107** | **0.839** |
+
+The exp106↔exp107 midpoint ratio is **less than 1.0**, meaning the midpoint has
+**lower loss** than both endpoints. 
This is a signature of a **connected flat valley**
+in the loss landscape — the two models are in the same basin, not separated by any
+barrier. Averaging exp106 and exp107 would produce an even better model (in terms
+of raw val_bpb), though this is not useful for TTT purposes.
+
+All other pairs show midpoint ratios > 1.0, consistent with separate minima found
+by different random seeds.
+
+---
+
+## 4. Comparison Against Decision Thresholds
+
+From the pre-run README:
+
+| TTT delta | Threshold | Outcome |
+|---|---|---|
+| < −0.026 | SAM helps — integrate | ✗ Not achieved |
+| −0.026 to −0.024 | Marginal — try rho sweep | ✗ Not in range |
+| **−0.024 to −0.023** | **Same ceiling confirmed** | **✓ Actual: −0.0234** |
+| > −0.023 | SAM hurts — discard | ✗ Not reached on the delta alone (see note) |
+
+Note: the TTT delta of −0.0234 is technically within the "same ceiling" band.
+However, since exp107's absolute legal_ttt (1.1190) is WORSE than exp106's (1.11469
+float-path, ~1.118 int6 full), SAM is a net regression. The "SAM hurts" verdict
+is appropriate.
+
+---
+
+## 5. The Core Problem: Meta-Learning Cannot Overcome Bank Geometry
+
+After four experiments, the pattern is clear:
+
+```
+TTT delta is determined by:
+  bank_dim (64) × num_banks (4) × rank_per_layer (~22 effective)
+  ─────────────────────────────────────────────────────────────── = constant ceiling
+  sequence_diversity × TTT_lr × TTT_epochs × TTT_momentum
+```
+
+The numerator is fixed by architecture. The denominator is fixed by eval TTT config.
+No amount of meta-training changes either term.
+
+What meta-training CAN change:
+- Where in bank-space the models initialize for TTT (±1% of total TTT trajectory)
+- How smooth the loss surface is around that initialization (SAM's contribution)
+- How many training steps were used to reach that initialization (quality of base model)
+
+What meta-training CANNOT change:
+- The amount of information the banks can absorb per token of TTT data
+- The expressivity ceiling of a 64-dimensional bank
+- The 128× step-count gap between inner-loop and eval TTT
+
+---
+
+## 6. What Might Actually Help
+
+Based on four experiments, the interventions most likely to improve legal_ttt are:
+
+### Priority 1: Increase bank_dim (highest expected gain, 3–8% TTT delta improvement)
+
+The bank capacity is 64-dimensional. Increasing to 96 or 128 dimensions would:
+- Allow more gradient information per TTT step to be absorbed
+- Provide more effective rank for adaptation
+- Cost ~0.3–0.7 MB of the 16MB budget (GPTQ-quantized)
+
+**But**: wider banks add parameters that must be trained. At 27M params, expanding
+banks from 64→96 adds ~2M params (~7%), potentially hurting the base model quality
+unless compensated by pruning elsewhere.
+
+### Priority 2: Increase TTT learning rate or epochs (free, no model changes)
+
+Current: lr=0.004, epochs=4. The TTT runs for 2228s on 947 chunks on H100.
+If wall time allows, increasing to epochs=6 or lr=0.006 might improve the tail
+of the TTT run. Risk: oscillation on easy sequences (already adapted well at epoch 4).
+
+### Priority 3: Multi-scale bank structure (medium complexity)
+
+Replace the flat 3D bank (layer × 4 × dim) with a hierarchical structure where
+some banks adapt quickly (high lr, low dim) and others adapt slowly (low lr, high dim).
+This is functionally similar to what MetaSGD was supposed to do — but implemented as
+an architectural constraint rather than a learned parameter.
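+
+A minimal sketch of the TTT side of this idea, assuming the two bank groups are
+exposed as separate parameters. The names (`fast_banks`, `slow_banks`), the 32/96
+dims, and the per-group lrs are illustrative assumptions, not values from this
+codebase:
+
+```python
+import torch
+
+# Hypothetical two-scale bank layout, following the existing
+# (layers x banks x 1536 x dim) shape convention:
+fast_banks = torch.nn.Parameter(torch.zeros(11, 4, 1536, 32))  # small, adapts fast
+slow_banks = torch.nn.Parameter(torch.zeros(11, 4, 1536, 96))  # large, adapts slowly
+
+# Fixed per-group lrs replace MetaSGD's 66 learned scalars with an
+# architectural constraint: the split itself encodes adaptation speed.
+ttt_optimizer = torch.optim.SGD(
+    [
+        {"params": [fast_banks], "lr": 0.008},  # ~2x the current TTT lr (0.004)
+        {"params": [slow_banks], "lr": 0.002},  # ~0.5x the current TTT lr
+    ],
+    momentum=0.9,
+)
+```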
+ +### Priority 4: Abandon meta-TTT entirely + +exp105a (no meta-TTT) achieved legal_ttt 1.11624, essentially identical to exp101's +1.11588 (with meta-TTT). The 3% FOMAML overhead buys nothing. The compute spent on +meta-TTT overhead (every=4, freeze_blocks=2) would be better spent on more training +steps with a better architecture. + +--- + +## 7. Recommendations + +1. **Close the meta-TTT line of investigation.** Four experiments confirm the ceiling. + The next improvement must come from architecture, not inner-loop optimizer. + +2. **Next experiment**: expand bank_dim from 64→96, compensating by reducing + attention_dim or reducing num_heads. The key constraint is staying under 16MB. + +3. **Secondary option**: try TTT with AdaGrad or RMSProp instead of SGD+cosine. + Per-parameter adaptive step sizes might exploit the bank geometry better than + uniform momentum — without requiring any meta-training changes. + +4. **Baseline to beat**: legal_ttt 1.11469 (exp106 float-path) / 1.11624 (exp105a int6). + Any new exp must beat this on the canonical int6 path. diff --git a/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/README.md b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/README.md new file mode 100644 index 0000000000..77dc188715 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/README.md @@ -0,0 +1,223 @@ +# exp107: SAM Inner Loop for Meta-TTT + +**Parent**: exp106 — 11L XSA-all · BigramHash 4096×64 pos-conditional · VE7-10 + · FOMAML every=4 cross-chunk + delta-loss + MetaSGD · SGD+cosine TTT + · int6 GPTQ+lzma (float-path legal_ttt **1.11469**) + +**Change**: Replace MetaSGD (C) with Sharpness-Aware Minimization (D) in the +FOMAML inner loop. No architecture change. + +**Verdict**: ❌ **SAM hurts — discard.** TTT delta invariant at ~0.023 bpb. +Absolute legal_ttt 1.1190 is worse than exp106's float-path 1.11469. The TTT +ceiling is architecture-limited, not inner-loop-optimizer-limited. + +--- + +## Results + +| Metric | exp105a (no meta-TTT) | exp106 (MetaSGD) | **exp107 (SAM)** | +|--------|----------------------|------------------|------------------| +| Training steps | ~6892 | 6686 | **6597** | +| Float val_bpb | 1.1353 | 1.1377 | **1.1384** | +| Int6 roundtrip val_bpb | 1.1396 | 1.1416 | **1.1424** | +| Legal TTT val_bpb | 1.11624 | 1.11469† | **1.11898** | +| TTT delta (int6) | −0.0234 | ~−0.023 | **−0.0234** | +| Peak GPU memory | ~23 GB | 31.7 GB | **32.4 GB** | +| Per-step time | ~727 ms | ~718 ms | **~728 ms** | +| Submission size | — | — | **15.88 MB** | + +†exp106 legal_ttt measured on float-path (int6 canonical path partial at 80%: 1.11800) + +**TTT delta is identical (−0.023 bpb) across all three formulations.** +SAM changed the gradient direction in the inner loop — but the result at eval time +is indistinguishable from vanilla SGD inner loop. + +--- + +## 1. Motivation (Pre-run) + +### Why replace MetaSGD with SAM? + +exp106's three-way analysis revealed two things: + +1. **MetaSGD failed**: All 66 learned per-layer LR scales converged to their 1.0 + initialization. The meta-gradient signal (1 step per 4, at ~30% of main gradient + magnitude) was too weak to drive per-layer differentiation. MetaSGD cost +8.6 GB + peak memory and −334 training steps for zero benefit. + +2. **Bank curvature is invariant**: Condition numbers (1.03–1.38), effective ranks + (22/11), and energy distributions are identical across all three models (exp101, + exp105a, exp106). 
The TTT delta is ~0.023 bpb regardless of meta-TTT formulation.
+
+But all three experiments used **vanilla SGD** in the inner loop. The gradient
+DIRECTION was always the same — only the meta-objective and step size varied.
+
+**SAM changes the gradient direction itself.** Instead of descending along `∇L(θ)`,
+SAM descends along `∇L(θ + ε)` where `ε = ρ · ∇L / ‖∇L‖` is a small ascent
+step. This gradient points toward **flatter minima** — regions where small
+perturbations don't increase loss. If the ~0.023 bpb TTT ceiling is determined by
+local curvature, SAM's explicit flatness-seeking might change it.
+
+### Why SAM might work where MetaSGD didn't
+
+| Property | MetaSGD | SAM |
+|---|---|---|
+| What it changes | Step SIZE per layer (scalar scale) | Gradient DIRECTION (via ascent perturbation) |
+| Free parameters | 66 (need to be learned) | 0 (rho is a fixed hyperparameter) |
+| Signal requirement | Needs meta-gradient to push 66 params away from init | Operates on each gradient independently |
+| Memory cost | +8.6 GB (gradient graph for differentiable non-leaf) | +2 GB (one extra forward pass of activations) |
+| Theoretical target | Per-layer adaptation speed differentiation | Flatness of the adapted banks |
+
+---
+
+## 2. Maths
+
+### Standard inner loop (exp101/exp105a/exp106):
+
+$$
+\theta' = \theta - \alpha \cdot \nabla_\theta \mathcal{L}(\theta;\, \mathcal{B}_A)
+$$
+
+### SAM inner loop (exp107):
+
+Step 1 — Ascent perturbation:
+
+$$
+\hat{\epsilon} = \rho \cdot \frac{\nabla_\theta \mathcal{L}(\theta;\, \mathcal{B}_A)}
+                              {\|\nabla_\theta \mathcal{L}(\theta;\, \mathcal{B}_A)\|}
+$$
+
+Step 2 — Sharpness-aware gradient (gradient at the perturbed point):
+
+$$
+g_\text{SAM} = \nabla_\theta \mathcal{L}(\theta + \hat{\epsilon};\, \mathcal{B}_A)
+$$
+
+Step 3 — Descent using SAM gradient:
+
+$$
+\theta' = \theta - \alpha \cdot g_\text{SAM}
+$$
+
+The outer loop is unchanged from exp106:
+
+$$
+\mathcal{L}_\text{meta} = (w_\text{post} + w_\Delta) \cdot \mathcal{L}(\theta';\, \mathcal{B}_B)
+                        - w_\Delta \cdot \mathcal{L}(\theta;\, \mathcal{B}_B)
+$$
+
+---
+
+## 3. Implementation
+
+### Changes to `meta_ttt_step` (the only modified function):
+
+```python
+# After computing vanilla gradient g and applying freeze mask...
+if sam_on:
+    # Joint gradient norm across all 4 banks
+    grad_norm = (g_qo.float().norm()**2 + g_kv.float().norm()**2 +
+                 g_up.float().norm()**2 + g_down.float().norm()**2
+                 ).sqrt().clamp(min=1e-12)
+
+    # Ascent perturbation
+    with torch.no_grad():
+        scale = rho / grad_norm
+        qo_pert = (qo_in.detach() + scale * g_qo).requires_grad_(True)
+        kv_pert = (kv_in.detach() + scale * g_kv).requires_grad_(True)
+        up_pert = (up_in.detach() + scale * g_up).requires_grad_(True)
+        down_pert = (down_in.detach() + scale * g_down).requires_grad_(True)
+
+    # SAM gradient at the perturbed point
+    loss_pert = base_model.forward_with_banks(
+        x_inner, y_inner, qo_pert, kv_pert, up_pert, down_pert)
+    g_sam_qo, g_sam_kv, g_sam_up, g_sam_down = torch.autograd.grad(
+        loss_pert, [qo_pert, kv_pert, up_pert, down_pert])
+
+    # Use the SAM gradients instead of g for the adapted banks
+    qo_new = qo_in.detach() - lr * g_sam_qo
+    kv_new = kv_in.detach() - lr * g_sam_kv
+    up_new = up_in.detach() - lr * g_sam_up
+    down_new = down_in.detach() - lr * g_sam_down
+```
+
+### Removed from exp106:
+- `meta_sgd_{qo,kv,up,down}` nn.Parameters from `GPT.__init__`
+- MetaSGD optimizer param group
+- MetaSGD export filter and strict-load hotfix
+
+### New env vars:
+- `META_TTT_SAM_ENABLED=1` — enable SAM inner loop
+- `META_TTT_SAM_RHO=0.05` — perturbation radius
+- `META_TTT_SAM_ADAPTIVE=0` — 0=vanilla SAM, 1=adaptive (scale ε by |param|)
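+
+For reference, the new flags are read into `Hyperparameters` with the same
+`os.environ` pattern as the existing meta-TTT options; the corresponding lines
+in `train_gpt.py` are:
+
+```python
+meta_ttt_sam_enabled = bool(int(os.environ.get("META_TTT_SAM_ENABLED", "0")))
+meta_ttt_sam_rho = float(os.environ.get("META_TTT_SAM_RHO", "0.05"))
+meta_ttt_sam_adaptive = bool(int(os.environ.get("META_TTT_SAM_ADAPTIVE", "0")))
+```
+
+---
+
+## 4. 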
Budget Analysis (Predicted vs Actual)
+
+### Memory
+
+| Component | exp106 | exp107 predicted | **exp107 actual** |
+|---|---|---|---|
+| MetaSGD gradient graph | +8.6 GB | 0 | **0** |
+| SAM extra forward activations | 0 | ~2.0 GB | **~2.7 GB** |
+| SAM perturbation tensors | 0 | ~0.1 GB | **~0.1 GB** |
+| **Peak** | **31,695 MiB** | **~25,200 MiB** | **32,397 MiB** |
+
+**The memory prediction was wrong.** SAM's extra forward pass holds activations for
+ALL 11 layers simultaneously (needed for the backward through the perturbed forward),
+which costs more than the 2 GB estimated. The MetaSGD gradient graph was surprisingly
+efficient at storing only the parameter-level graph nodes, not full activations.
+Net result: exp107 peak memory is +702 MiB HIGHER than exp106, not −6.5 GB lower.
+
+### Compute
+
+| Metric | exp106 | exp107 predicted | **exp107 actual** |
+|---|---|---|---|
+| Per-step time | ~718 ms | ~706 ms | **~728 ms** |
+| Steps in 4800s | 6686 | ~6800 | **6597** |
+
+SAM's extra activation memory caused more GPU memory pressure, slightly reducing
+throughput. exp107 ran **89 fewer steps** than exp106 — the opposite of the prediction.
+
+---
+
+## 5. Decision Thresholds
+
+Compare TTT delta against exp105a's baseline of −0.02331 bpb:
+
+| TTT delta | Verdict |
+|---|---|
+| < −0.026 (>10% better) | SAM genuinely helps — integrate into future runs |
+| −0.026 to −0.024 | Marginal — try rho sweep {0.01, 0.02, 0.1, 0.2} |
+| **−0.024 to −0.023 (actual: −0.0234)** | **Same ceiling — architecture-limited hypothesis confirmed** |
+| > −0.023 | SAM hurts — discard |
+
+**Verdict**: Same ceiling confirmed. TTT delta = −0.023 bpb across all 4 experiments
+(exp101, exp105a, exp106, exp107). The ceiling is set by the bank architecture
+(rank × dim), not by inner-loop optimizer choice.
+
+---
+
+## 6. Post-Hoc Weight-Space Analysis
+
+Four-way principal-angle + midpoint analysis (exp101, exp105a, exp106, exp107):
+
+| Pair | Bank cosine | Midpoint ratio | Interpretation |
+|---|---|---|---|
+| exp105a ↔ exp101 | ~0.05 | ~1.12 | Different basins (different seeds) |
+| exp106 ↔ exp101 | ~0.05 | ~1.14 | Different basins |
+| **exp107 ↔ exp106** | **0.2025** | **0.839** | **Same basin — mildest perturbation** |
+| exp107 ↔ exp105a | ~0.05 | ~1.12 | Different basins |
+
+exp107 is the most similar to exp106 of any pair in the series. SAM barely shifted the
+trained weights — it is the smallest perturbation in all four experiments. This is
+consistent with SAM's rho=0.05 being a tiny fraction of the bank norms (~2.7).
+
+---
+
+## Run
+
+```bash
+bash records/phase3/exp107_sam-inner-metattt_from_exp106/run.sh
+```
+
+Hardware: **1×H100 80 GB SXM**, `MAX_WALLCLOCK_SECONDS=4800` (80-minute cap).
+Iso-compute with the competition's 8×H100 @ 10-min budget.
+
+**Actual completion**: 6597/7500 steps (wallclock cap), seed=42.
diff --git a/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/run.sh b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/run.sh
new file mode 100755
index 0000000000..1049b8ab3a
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/run.sh
@@ -0,0 +1,156 @@
+#!/bin/bash
+# ============================================================
+# exp107: SAM inner loop for meta-TTT (replaces MetaSGD from exp106)
+# Branched from exp106_metasgd-crosschunk-delta_from_exp101.
+# +# Motivation: exp106's MetaSGD (66 learned per-layer inner-loop LR scales) +# converged to uniform 1.0 — no per-layer differentiation learned — while +# costing +8.6 GB peak memory and -334 training steps. SAM replaces MetaSGD +# with a DIFFERENT approach to improving the inner-loop: instead of learning +# per-layer step sizes, SAM changes the gradient DIRECTION to point toward +# flatter minima. +# +# SAM inner loop (D): +# 1. Compute gradient g at current banks (vanilla forward+backward) +# 2. Perturb banks in ascent direction: banks_pert = banks + rho * g / ||g|| +# 3. Compute gradient g_sam at the perturbed point (second forward+backward) +# 4. Use g_sam for adaptation: banks' = banks - lr * g_sam +# This finds adapted banks in flatter regions of the loss landscape. +# If the TTT ceiling (~0.023 bpb) is set by local curvature, SAM may break it. +# +# Changes from exp106: +# REMOVED: META_SGD_ENABLED, META_SGD_LR (66 learned params, converged to 1.0) +# ADDED: META_TTT_SAM_ENABLED=1, META_TTT_SAM_RHO=0.05 +# KEPT: META_TTT_SPLIT=batch (A), META_TTT_DELTA_WEIGHT=0.3 (B) +# +# Expected memory: ~25 GB peak (vs exp106's 31.7 GB) — net win from dropping MetaSGD. +# Expected steps: ~6800 in 4800s (vs exp106's 6686) — faster per-step from lower memory. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp107_sam-inner-metattt_from_exp106" +cd /workspace/parameter-golf + +# --- 8xH100 simulation --- +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-4800}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-786432}" +export ITERATIONS="${ITERATIONS:-7500}" +export WARMDOWN_ITERS="${WARMDOWN_ITERS:-2500}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" + +# --- Eval --- +export EVAL_STRIDE=64 +export EVAL_BATCH_SEQS=128 +export SEED="${SEED:-42}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export EVAL_SEQ_LEN="${EVAL_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-3000}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-500}" + +# --- Architecture --- +export NUM_LAYERS=11 +export XSA_LAST_N=11 +export ROPE_DIMS=16 +export LN_SCALE=1 + +# --- Smaller bigram (saves ~1.5 MB → eliminates ±1 pruning) --- +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=64 + +# --- Bigram layout (pos-conditional split, no trigram — matches exp105a baseline) --- +export POS_CONDITIONAL_BIGRAM=1 +export TRIGRAM=0 + +# --- Wider Value Embeddings (layers 7-10) --- +export VE_ENABLED=1 +export VE_DIM=128 +export VE_LAYERS="7,8,9,10" + +# --- Earlier Late QAT (threshold 0.25) --- +export QAT_ENABLED=0 +export LATE_QAT_THRESHOLD=0.25 + +# --- Adaptive Warmdown --- +export ADAPTIVE_WARMDOWN=1 +export ADAPTIVE_WARMDOWN_EMA=0.99 +export ADAPTIVE_WARMDOWN_THRESHOLD=0.0005 +export ADAPTIVE_WARMDOWN_MIN_STEPS=2000 + +# --- Learning rates --- +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 + +# --- Weight decay --- +export MUON_WD=0.04 +export ADAM_WD=0.04 + +# --- EMA --- +export EMA_ENABLED=1 +export EMA_DECAY=0.998 +export EMA_UPDATE_EVERY=10 + +# --- SWA --- +export SWA_ENABLED=1 +export SWA_EVERY=50 + +# --- Fixed momentum 0.99 (meta-TTT needs stable high momentum) --- +export MOMENTUM_CYCLIC=0 +export MUON_MOMENTUM=0.99 +export MUON_MOMENTUM_WARMUP_START=0.92 +export MUON_MOMENTUM_WARMUP_STEPS=1500 + +# --- Newton-Schulz --- +export MUON_BACKEND_STEPS=5 + +# --- Grad clipping --- +export GRAD_CLIP_NORM=0.3 + +# --- GPTQ --- +export 
GPTQ_CALIB_BATCHES=256 +export GPTQ_BLOCK_SIZE=128 +export TARGET_MB=15.9 + +# --- Meta-TTT (FOMAML + cross-chunk (A) + delta-loss (B) + SAM inner (D)) --- +export META_TTT_ENABLED=1 +export META_TTT_INNER_LR=0.002 +export META_TTT_EVERY=4 +export META_TTT_LOSS_WEIGHT=0.5 +export META_TTT_FREEZE_BLOCKS=2 +# (A) Cross-chunk split (from exp106) +export META_TTT_SPLIT=batch +# (B) Delta-loss (from exp106) +export META_TTT_DELTA_WEIGHT=0.3 +# (D) SAM inner loop (exp107 — replaces MetaSGD) +export META_TTT_SAM_ENABLED=1 +export META_TTT_SAM_RHO=0.05 +export META_TTT_SAM_ADAPTIVE=0 + +# --- TTT (eval time) — SGD + cosine, unchanged --- +export TTT_ENABLED=1 +export TTT_LR=0.004 +export TTT_EPOCHS=4 +export TTT_CHUNK_TOKENS=65536 +export TTT_FREEZE_BLOCKS=2 +export TTT_MOMENTUM=0.9 +export TTT_BATCH_SEQS=16 +export TTT_GRAD_CLIP=1.0 + +export RUN_ID="${EXP_NAME}_seed${SEED}" +echo "=== ${EXP_NAME} seed=${SEED} ===" +echo "=== exp107: cross-chunk (A) + delta-loss (B) + SAM inner (D, rho=${META_TTT_SAM_RHO}) ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" +echo "=== ${EXP_NAME} COMPLETE ===" + +# Save model artifacts for submission +echo "=== Saving model artifacts ===" +if [ -f "final_model.pt" ]; then + cp final_model.pt "${SCRIPT_DIR}/" + echo "Saved final_model.pt" +fi +if [ -f "final_model.int6.ptz" ]; then + cp final_model.int6.ptz "${SCRIPT_DIR}/" + echo "Saved final_model.int6.ptz" +fi +echo "=== Done ===" diff --git a/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/scripts/ttt_eval.py b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/scripts/ttt_eval.py new file mode 100644 index 0000000000..2c5781aa62 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/scripts/ttt_eval.py @@ -0,0 +1,220 @@ +"""Standalone TTT eval with SGD optimizations on an already-quantized exp101 model.""" +import sys, os, glob, math, time, io, lzma +import numpy as np +import torch +import torch.nn.functional as F +import torch.distributed as dist +from pathlib import Path + +# Add the exp101 code to path +sys.path.insert(0, "/workspace/parameter-golf/records/track_10min_16mb/exp101_poscond-bigram-trigram_from_exp95") +os.environ.setdefault("POS_CONDITIONAL_BIGRAM", "1") +os.environ.setdefault("TRIGRAM", "1") +os.environ["BIGRAM_VOCAB_SIZE"] = "4096" +os.environ["BIGRAM_DIM"] = "64" +os.environ["VE_LAYERS"] = "7,8,9,10" +os.environ["VE_ENABLED"] = "1" +os.environ["ROPE_DIMS"] = "16" +os.environ["LN_SCALE"] = "1" +os.environ["XSA_LAST_N"] = "11" +os.environ["NUM_LAYERS"] = "11" + +from train_gpt import ( + GPT, CastedLinear, Rotary, Hyperparameters, + build_sentencepiece_luts, load_validation_tokens, + _unbank_state_dict, _rebank_state_dict, + dequantize_mixed_int6, restore_low_dim_params_to_fp32, +) +import sentencepiece as spm + +device = torch.device("cuda") +args = Hyperparameters() + +# Load tokenizer and val data +sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) +val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) +base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + +# Load quantized model +print("Loading quantized model...") +with open("/workspace/parameter-golf/final_model.int6.ptz", "rb") as f: + quant_blob = f.read() +quant_state = torch.load(io.BytesIO(lzma.decompress(quant_blob)), map_location="cpu") + +# Load raw model to get template state dict for rebanking +raw_sd = 
torch.load("/workspace/parameter-golf/final_model.pt", map_location="cpu") + +# Dequantize +unbanked_sd = _unbank_state_dict({k: v.detach().cpu() for k, v in raw_sd.items()}, args.num_layers) +deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) +deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, raw_sd) + +# Build model +print("Building model...") +model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, +).to(device).bfloat16() +model.qo_bank.data = model.qo_bank.data.float() +model.kv_bank.data = model.kv_bank.data.float() +model.mlp_up_bank.data = model.mlp_up_bank.data.float() +model.mlp_down_bank.data = model.mlp_down_bank.data.float() +for m in model.modules(): + if isinstance(m, CastedLinear): + m.float() +restore_low_dim_params_to_fp32(model) +model.load_state_dict(deq_state, strict=True) +model._has_leading_space = has_leading_space_lut + +print(f"Model loaded. Params: {sum(p.numel() for p in model.parameters()):,}") + +# --- TTT with optimized SGD --- +seq_len = args.train_seq_len +total_tokens = val_tokens.numel() - 1 +stride = 64 + +# === TUNED HYPERPARAMS === +ttt_lr = 0.002 # [1] higher than 0.001 — old cosine peak was 0.001, now flat +ttt_epochs = 3 # keep 3 (4 risks overfitting per chunk with SGD) +ttt_chunk = 65536 # [2] larger chunks — more data per adaptation, less overfitting +ttt_freeze_blocks = 2 +ttt_momentum = 0.9 +ttt_nesterov = True # [3] Nesterov look-ahead — faster convergence, free +ttt_wd = 0.001 # [4] small weight decay — regularizes per-chunk adaptation +ttt_grad_clip = 1.0 +eval_batch = 128 +train_batch = 16 + +window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] +num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk +chunk_windows = [[] for _ in range(num_chunks)] +for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + +# Freeze first N blocks +frozen_ids = set(range(ttt_freeze_blocks)) +ttt_params = [] +for name, p in model.named_parameters(): + freeze = any(f"blocks.{bi}." 
in name for bi in frozen_ids) + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + +unfrozen_n = sum(p.numel() for p in ttt_params) +frozen_n = sum(p.numel() for p in model.parameters() if not p.requires_grad) +print(f"TTT: SGD lr={ttt_lr} momentum={ttt_momentum} nesterov={ttt_nesterov} " + f"wd={ttt_wd} epochs={ttt_epochs} chunks={num_chunks} chunk_tokens={ttt_chunk}") +print(f"TTT: unfrozen={unfrozen_n:,} frozen={frozen_n:,}") + +# [1,3,4] SGD with Nesterov + weight decay +optimizer = torch.optim.SGD(ttt_params, lr=ttt_lr, momentum=ttt_momentum, + nesterov=ttt_nesterov, weight_decay=ttt_wd) + +loss_sum = torch.zeros((), device=device, dtype=torch.float64) +token_count = torch.zeros((), device=device, dtype=torch.float64) +byte_count = torch.zeros((), device=device, dtype=torch.float64) +t0 = time.perf_counter() + +for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # Phase 1: SCORE (evaluate before training — legal TTT) + model.eval() + with torch.inference_mode(): + for bi in range(0, len(windows), eval_batch): + batch_ws = windows[bi:bi + eval_batch] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Phase 2: TRAIN with SGD + is_last = (ci == num_chunks - 1) + if not is_last and ttt_epochs > 0: + model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # [5] Flat LR — each chunk is independent data, + # cosine across chunks starved late chunks (lr→0) + for pg in optimizer.param_groups: + pg['lr'] = ttt_lr + + # [6] Reset momentum buffers between chunks — stale momentum + # from chunk N is noise for chunk N+1's different data + for p in ttt_params: + state = optimizer.state.get(p, {}) + if 'momentum_buffer' in state: + state['momentum_buffer'].zero_() + + for _ep in range(ttt_epochs): + for bs in range(0, chunk_seqs, train_batch): + be = min(bs + train_batch, chunk_seqs) + start_tok = chunk_start + bs * seq_len + end_tok = chunk_start + be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + torch.nn.utils.clip_grad_norm_(ttt_params, ttt_grad_clip) 
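+                    # global-norm clip over all unfrozen params (ttt_grad_clip=1.0)
+                    # before each SGD step, bounding per-mini-batch drift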
+ optimizer.step() + + if ci % 100 == 0 or ci == num_chunks - 1: + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) + pct = (ci + 1) / num_chunks * 100 + eta = (elapsed / max(ci + 1, 1)) * (num_chunks - ci - 1) + print(f" chunk {ci+1}/{num_chunks} ({pct:.1f}%) bpb={rbpb:.6f} ETA={eta:.0f}s") + +val_loss = (loss_sum / token_count).item() +val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) +print(f"\nFINAL TTT (SGD nesterov, flat LR={ttt_lr}): val_loss={val_loss:.6f} val_bpb={val_bpb:.6f}") + +for p in model.parameters(): + p.requires_grad_(True) diff --git a/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/submission.json b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/submission.json new file mode 100644 index 0000000000..6bf831c7de --- /dev/null +++ b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/submission.json @@ -0,0 +1,51 @@ +{ + "experiment": "exp107_sam-inner-metattt_from_exp106", + "description": "SAM inner loop for FOMAML meta-TTT. Replaces exp106's MetaSGD (which converged to uniform 1.0) with Sharpness-Aware Minimization in the inner adaptation step. SAM computes the gradient at an ascent-perturbed point (banks + rho*g/||g||), finding flatter adapted banks. Keeps cross-chunk split (A) and delta-loss (B) from exp106. RESULT: TTT delta invariant at ~0.023 bpb (4th consecutive experiment confirming the architecture-limited ceiling). SAM added +0.7 GB memory vs predicted -6.5 GB. Verdict: SAM hurts — discard.", + "non_record": true, + "experiment_type": "meta-ttt-exploration", + "track": "10min_16mb", + "hardware": "1xH100_80GB_SXM", + "wallclock_seconds": 4800, + "seed": 42, + "parent": "exp106_metasgd-crosschunk-delta_from_exp101", + "parent_arch": "11L XSA-all · BigramHash 4096x64 pos-conditional · VE7-10 · partial RoPE 16/64 · FOMAML every=4 cross-chunk + delta-loss + MetaSGD · SGD+cosine TTT · int6 GPTQ+lzma", + "changes_from_parent": { + "removed": "MetaSGD (66 learned per-layer LR scales — converged to uniform 1.0, +8.6 GB peak memory)", + "added": "SAM inner loop (rho=0.05, vanilla SAM — gradient at ascent-perturbed point)" + }, + "meta_ttt_config": { + "A_cross_chunk_split": true, + "B_delta_loss_weight": 0.3, + "C_metasgd_enabled": false, + "D_sam_enabled": true, + "D_sam_rho": 0.05, + "D_sam_adaptive": false + }, + "training": { + "steps_completed": 6597, + "steps_planned": 7500, + "stop_reason": "wallclock_cap", + "per_step_ms": 728, + "peak_memory_mib": 32397, + "warmdown_triggered_at_step": 2200 + }, + "val_bpb": 1.11898178, + "val_bpb_note": "int6+lzma canonical TTT (full 947/947 chunks); post-EMA float baseline: 1.1384", + "pre_quant_val_bpb": 1.1384, + "int6_roundtrip_val_bpb": 1.14235518, + "legal_ttt_val_bpb": 1.11898178, + "ttt_delta_bpb": -0.02337, + "submission_size_bytes": 15884642, + "submission_size_mb": 15.88, + "actual_memory_peak_mib": 32397, + "actual_steps": 6597, + "memory_prediction_error": "Predicted -6.5 GB vs exp106; actual was +0.7 GB. 
SAM forward activations cost ~3.3 GB (vs 2.0 GB estimated) because all 11 layers are held simultaneously for autograd.grad.", + "decision_thresholds": { + "ttt_delta_better_than_minus0.026": "SAM helps — integrate", + "ttt_delta_minus0.024_to_minus0.026": "Marginal — try rho sweep", + "ttt_delta_minus0.023_to_minus0.024": "Same ceiling — architecture-limited confirmed", + "ttt_delta_worse_than_minus0.023": "SAM hurts — discard" + }, + "verdict": "SAM hurts — discard. TTT delta -0.0234 bpb, same as all prior experiments. Absolute legal_ttt 1.1190 is worse than exp106 (1.11469 float) and exp105a (1.11624 int6). SAM adds +0.7 GB memory and -89 training steps with zero TTT improvement. The TTT ceiling is architecture-limited: bank_dim=64 × TTT_epochs=4. Meta-TTT line of investigation closed after 4 experiments (exp101, exp105a, exp106, exp107).", + "conclusion": "TTT delta (~0.023 bpb) invariant across all inner-loop formulations (vanilla SGD same-batch, vanilla SGD cross-chunk+delta-loss+MetaSGD, SAM SGD). The ceiling is set by bank architecture (rank × dim), not by optimizer. SAM specifically fails because: (1) bank geometry is already isotropic (SV uniformity 0.999) — no sharpness to avoid; (2) 4-epoch TTT overshoots any initialization bias; (3) 128× step-count gap between meta inner loop and eval TTT erases initialization signal. Weight-space analysis confirms exp107 stays in the same basin as exp106 (bank cosine 0.2025, midpoint ratio 0.839 < 1.0 = flat valley). Recommend pivoting to bank_dim expansion or TTT optimizer swap (AdaGrad/RMSProp)." +} diff --git a/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/train_gpt.py b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/train_gpt.py new file mode 100644 index 0000000000..45e8d090da --- /dev/null +++ b/records/track_non_record_16mb/2026-04-13_sam-inner-metattt_from_exp106/train_gpt.py @@ -0,0 +1,2417 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + 
eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 128)) + + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + trigram_enabled = bool(int(os.environ.get("TRIGRAM", "0"))) + + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + + # EMA + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", "0.997")) + + # SWA / LAWA (from remote) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + + # Adaptive warmdown + adaptive_warmdown = bool(int(os.environ.get("ADAPTIVE_WARMDOWN", "1"))) + adaptive_warmdown_ema = float(os.environ.get("ADAPTIVE_WARMDOWN_EMA", "0.99")) + adaptive_warmdown_threshold = float(os.environ.get("ADAPTIVE_WARMDOWN_THRESHOLD", "0.0005")) + adaptive_warmdown_min_steps = int(os.environ.get("ADAPTIVE_WARMDOWN_MIN_STEPS", "2000")) + + # Partial RoPE + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + + # LN scale and DTG + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + + # QAT + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + late_qat_threshold = 
float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # Value Embedding + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + + # Gated attention / Value residual + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + + # GPTQ calibration + gptq_calib_batches = int(os.environ.get("GPTQ_CALIB_BATCHES", 256)) + gptq_block_size = int(os.environ.get("GPTQ_BLOCK_SIZE", 128)) + + # TTT (Test-Time Training) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", "0.002")) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "3")) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", "32768")) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", "2")) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", "0.9")) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", "16")) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", "1.0")) + + # Meta-TTT (FOMAML: train-time meta-learning for better TTT initialization) + meta_ttt_enabled = bool(int(os.environ.get("META_TTT_ENABLED", "1"))) + meta_ttt_inner_lr = float(os.environ.get("META_TTT_INNER_LR", "0.002")) + meta_ttt_every = int(os.environ.get("META_TTT_EVERY", "8")) + meta_ttt_loss_weight = float(os.environ.get("META_TTT_LOSS_WEIGHT", "0.5")) + meta_ttt_freeze_blocks = int(os.environ.get("META_TTT_FREEZE_BLOCKS", "2")) + # exp106 kept: (A) cross-chunk split, (B) delta loss + # exp107 new: (D) SAM inner loop — replaces MetaSGD (C) which converged to uniform 1.0 + meta_ttt_split = os.environ.get("META_TTT_SPLIT", "batch").lower() + meta_ttt_delta_weight = float(os.environ.get("META_TTT_DELTA_WEIGHT", "0.3")) + # (D) SAM: Sharpness-Aware Minimization in the inner loop. + # Replaces vanilla SGD with SAM — gradient at the ascent-perturbed point. + # This finds adapted banks in flatter regions of the loss landscape. + meta_ttt_sam_enabled = bool(int(os.environ.get("META_TTT_SAM_ENABLED", "0"))) + meta_ttt_sam_rho = float(os.environ.get("META_TTT_SAM_RHO", "0.05")) + meta_ttt_sam_adaptive = bool(int(os.environ.get("META_TTT_SAM_ADAPTIVE", "0"))) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. 
Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + 
has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,bigram.scale,word_start_boost,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = 
str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) 
-> Tensor:
+    # Assumes the standard fineweb .bin shard layout: a 256-int32 header
+    # (magic, version, token count, ...) followed by little-endian uint16
+    # token ids. The header/constant names here fill an extraction gap and
+    # are a best-effort reconstruction from that convention.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    data = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes)
+    return torch.from_numpy(data.astype(np.int32))
+
+class TokenStream:
+    """Serves contiguous token runs, cycling through shard files as needed."""
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# --- Transformer modules ---
+
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+class CastedLinear(nn.Linear):
+    _qat_enabled: bool = False
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight.to(x.dtype)
+        if CastedLinear._qat_enabled and self.training and w.ndim == 2:
+            # Straight-through estimator: forward sees the int6-rounded weight,
+            # backward passes gradients through the float weight unchanged.
+            with torch.no_grad():
+                w32 = self.weight.float()
+                row_max = w32.abs().amax(dim=1)
+                scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
+                w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)
+            w = w + (w_q - w).detach()
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, w, bias)
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+class Rotary(nn.Module):
+    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
+        super().__init__()
+        self.dim = dim
+        self.base = base
+        self.train_seq_len = train_seq_len
+        self.rope_dims = rope_dims if rope_dims > 0 else dim
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            rd = self.rope_dims
+            if seq_len > self.train_seq_len:
+                scale = seq_len / self.train_seq_len
+                new_base = self.base * (scale ** (rd / (rd - 2)))
+                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
+            else:
+                inv_freq = self.inv_freq.to(device)
+            t = torch.arange(seq_len, device=device,
dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vrl_alpha = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, + v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + alpha = torch.sigmoid(self.vrl_alpha.to(dtype=v.dtype)) + v = v + alpha * v0 + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, 
causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + """Hash-based n-gram embedding table. + + exp101 additions: + - pos_conditional: split buckets into two disjoint halves keyed on whether + the CURRENT token is a word-start. ws-current pairs hash into the lower + half [0, half), non-ws-current pairs into the upper half [half, 2*half). + Motivation: in exp95 the single shared table was dominated by within-word + (prev, curr) pairs, and the word_start_boost scalar collapsed to ~0.007 + (killing bigram at word-starts to suppress the noise). Splitting the + buckets gives word-start pairs their own exclusive rows that can learn + their own signal without contaminating within-word buckets, and vice + versa. Zero extra parameters — same 4096×dim table, different layout. + - trigram: optional (t-2, t-1, t) lookup summed into the same table. When + combined with pos_conditional, the trigram hash respects the same split + (keyed on whether t is a word-start) so a bucket is only trained by + lookups of a consistent word-start class. + """ + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int, + trigram: bool = False, pos_conditional: bool = False): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self._trigram = trigram + self._pos_conditional = pos_conditional + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor, has_leading_space: Tensor | None = None) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 # sentinel row at index mod + out = torch.empty_like(t) + out[..., 0] = mod + if self._pos_conditional and has_leading_space is not None: + half = mod // 2 # 2047 for bigram_vocab_size=4096 + base = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % half + is_ws_curr = has_leading_space[tokens[..., 1:].long()].to(torch.int32) + shift = (1 - is_ws_curr) * half # 0 for ws-current, half for non-ws-current + out[..., 1:] = base + shift + else: + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def trigram_hash(self, tokens: Tensor, has_leading_space: Tensor | None = None) -> Tensor: + """Hash (t-2, t-1, t) trigrams into same embedding table. 
Zero extra params.""" + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., :2] = mod + if self._pos_conditional and has_leading_space is not None: + half = mod // 2 + base = (36313 * t[..., 2:] ^ 27191 * t[..., 1:-1] ^ 51497 * t[..., :-2]) % half + is_ws_curr = has_leading_space[tokens[..., 2:].long()].to(torch.int32) + shift = (1 - is_ws_curr) * half + out[..., 2:] = base + shift + else: + out[..., 2:] = (36313 * t[..., 2:] ^ 27191 * t[..., 1:-1] ^ 51497 * t[..., :-2]) % mod + return out.long() + + def forward(self, token_ids: Tensor, has_leading_space: Tensor | None = None) -> Tensor: + h = self.embed(self.bigram_hash(token_ids, has_leading_space)) + if self._trigram: + h = h + self.embed(self.trigram_hash(token_ids, has_leading_space)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) + return F.linear(x.square(), down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP(dim, mlp_mult) + # Block 0: init attn_scale to 0.1 (analysis shows it converges to ~0.094 anyway) + attn_init = 0.1 if layer_idx == 0 else 1.0 + self.attn_scale = nn.Parameter(torch.full((dim,), attn_init, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, + up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) + x_out = x_in + 
self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding( + bigram_vocab_size, bigram_dim, model_dim, + trigram=bool(int(os.environ.get("TRIGRAM", "0"))), + pos_conditional=bool(int(os.environ.get("POS_CONDITIONAL_BIGRAM", "0"))), + ) if bigram_vocab_size > 0 else None + self.word_start_boost = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) if bigram_vocab_size > 0 else None + self._has_leading_space: Tensor | None = None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + # exp107: MetaSGD removed (converged to uniform 1.0 in exp106, +8.6 GB overhead). + # SAM inner loop in meta_ttt_step replaces it — no new nn.Parameters needed. 
+ self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + layer_idx=i, ln_scale=ln_scale, dtg=dtg, + gated_attention=gated_attention, value_residual=value_residual, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + bg = self.bigram(input_ids, self._has_leading_space) + if self.word_start_boost is not None and self._has_leading_space is not None: + ws_mask = self._has_leading_space[input_ids].unsqueeze(-1).to(dtype=bg.dtype) + bg = bg * (1.0 + ws_mask * (self.word_start_boost.to(dtype=bg.dtype) - 1.0)) + x = x + bg + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, 
raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + bg = self.bigram(input_ids, self._has_leading_space) + if self.word_start_boost is not None and self._has_leading_space is not None: + ws_mask = self._has_leading_space[input_ids].unsqueeze(-1).to(dtype=bg.dtype) + bg = bg * (1.0 + ws_mask * (self.word_start_boost.to(dtype=bg.dtype) - 1.0)) + x = x + bg + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward_with_banks(self, input_ids: Tensor, target_ids: Tensor, + qo_bank: Tensor, kv_bank: Tensor, + mlp_up_bank: Tensor, mlp_down_bank: Tensor) -> Tensor: + """Forward with 
external bank tensors (for meta-TTT). Pure next-token loss, no MTP.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + bg = self.bigram(input_ids, self._has_leading_space) + if self.word_start_boost is not None and self._has_leading_space is not None: + ws_mask = self._has_leading_space[input_ids].unsqueeze(-1).to(dtype=bg.dtype) + bg = bg * (1.0 + ws_mask * (self.word_start_boost.to(dtype=bg.dtype) - 1.0)) + x = x + bg + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + qo_bank[i], kv_bank[i], kv_bank[n + i], + qo_bank[n + i], mlp_up_bank[i], mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + qo_bank[bi], kv_bank[bi], kv_bank[n + bi], + qo_bank[n + bi], mlp_up_bank[bi], mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + +# --- TTT (Test-Time Training) --- + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + eval_batch = args.eval_batch_seqs + train_batch = args.ttt_batch_seqs + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + frozen_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = any(f"blocks.{bi}." 
in name for bi in frozen_ids) + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + unfrozen_n = sum(p.numel() for p in ttt_params) + frozen_n = sum(p.numel() for p in base_model.parameters() if not p.requires_grad) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}", flush=True) + log0(f"ttt_sliding:params unfrozen={unfrozen_n} frozen={frozen_n}", flush=True) + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), eval_batch): + batch_ws = my_windows[bi:bi + eval_batch] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last = (ci == num_chunks - 1) + if not is_last and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, train_batch): + be = min(bs + train_batch, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: 
+ for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + total_w = len(window_starts) + done_w = sum(len(chunk_windows[c]) for c in range(ci + 1)) + pct = done_w / total_w * 100 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + eta = (elapsed / max(done_w, 1)) * (total_w - done_w) + bar_len = 30 + filled = int(bar_len * done_w / total_w) + bar = "\u2588" * filled + "\u2591" * (bar_len - filled) + log0(f" ttt [{bar}] {pct:5.1f}% chunk {ci+1}/{num_chunks} bpb={rbpb:.6f} ETA={eta:.0f}s", flush=True) + if rank == 0: + log0("", flush=True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s", flush=True) + return val_loss, val_bpb + +# --- GPTQ quantization pipeline --- + +def generate_autoregressive_calib(model, device, num_seqs=64, seq_len=2048, + vocab_size=1024, temperature=0.8, batch_size=8, seed=42): + model.eval() + rng = torch.Generator(device=device) + rng.manual_seed(seed) + all_tokens = [] + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for batch_start in range(0, num_seqs, batch_size): + bs = min(batch_size, num_seqs - batch_start) + tokens = torch.randint(0, vocab_size, (bs, 1), device=device, generator=rng) + for pos in range(seq_len - 1): + logits = model.forward_logits(tokens) + next_logit = logits[:, -1, :] + probs = torch.softmax(next_logit / temperature, dim=-1) + next_tok = torch.multinomial(probs, 1, generator=rng) + tokens = torch.cat([tokens, next_tok], dim=1) + for i in range(bs): + all_tokens.append(tokens[i:i+1]) + return all_tokens + +def collect_hessians_from_tokens(hessian_model, token_seqs, device): + hessians = {} + hooks = [] + for name, module in hessian_model.named_modules(): + if isinstance(module, CastedLinear): + param_name = name + ".weight" + cols = module.weight.shape[1] + hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') + def make_hook(pname): + def hook_fn(module, input, output): + x = input[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[pname] += (x.T @ x).cpu() + return hook_fn + h = module.register_forward_hook(make_hook(param_name)) + hooks.append(h) + hessian_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for seq in token_seqs: + x = seq[:, :-1].to(device) + y = seq[:, 1:].to(device) + hessian_model(x, y) + for h in hooks: + h.remove() + num_batches = len(token_seqs) + for name in hessians: + H = hessians[name] + H /= num_batches + damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6) + H += damp * torch.eye(H.shape[0]) + hessians[name] = H + return hessians + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." 
in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128): + t32 = weight.float() + if t32.ndim != 2 or hessian is None: + return _quantize_int6_percentile(t32, clip_range) + rows, cols = t32.shape + H = hessian.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * torch.mean(torch.diag(H)) + H[torch.arange(cols), torch.arange(cols)] += damp + perm = torch.argsort(torch.diag(H), descending=True) + inv_perm = torch.argsort(perm) + W = t32[:, perm].clone() + W[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.linalg.cholesky(H) + Hinv = torch.cholesky_inverse(Hinv) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + best_q = None; best_scale = None; best_err = float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + sf = s.float() + Q = torch.zeros_like(W, dtype=torch.int8) + W_work = W.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + count = i2 - i1 + W1 = W_work[:, i1:i2].clone() + Q1 = torch.zeros(rows, count, dtype=torch.int8) + Err1 = torch.zeros(rows, count) + Hinv1 = Hinv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8) + Q1[:, i] = q + err = (w - q.float() * sf) / d + W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0) + Err1[:, i] = err + Q[:, i1:i2] = Q1 + if i2 < cols: + W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + recon = Q.float() * sf[:, None] + mse = (W - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + best_q = best_q[:, inv_perm] + return best_q, best_scale + +def _quantize_int6_percentile(t32, clip_range=31): + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = 
torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + n = num_layers + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: qo_slices[i] = sd[qk]; consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: qo_slices[n + i] = sd[ok]; consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: kv_slices[i] = sd[kk]; consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: kv_slices[n + i] = sd[vk]; consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: up_slices[i] = sd[fk]; consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: down_slices[i] = sd[dk]; consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +class _HessianAttn(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape; Hkv = v.size(-2); group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x, v_embed=None): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)); k = F.rms_norm(k, 
(k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: y = self._xsa_efficient(y, v) + return self.proj(y.reshape(bsz, seqlen, dim)) + +class _HessianMLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.fc = CastedLinear(dim, int(mlp_mult * dim), bias=False) + self.proj = CastedLinear(int(mlp_mult * dim), dim, bias=False) + def forward(self, x): + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + +class _HessianBlock(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, layer_idx=0, ln_scale=False): + super().__init__() + self.attn_norm = RMSNorm(); self.mlp_norm = RMSNorm() + self.attn = _HessianAttn(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = _HessianMLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + def forward(self, x, x0, v_embed=None): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + +class _HessianGPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, logit_softcap, rope_base, qk_gain_init, + bigram_vocab_size=0, bigram_dim=128, xsa_last_n=0, + rope_dims=0, ln_scale=False, + ve_enabled=False, ve_dim=128, ve_layers="9,10"): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.num_layers = num_layers + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding( + bigram_vocab_size, bigram_dim, model_dim, + trigram=bool(int(os.environ.get("TRIGRAM", "0"))), + pos_conditional=bool(int(os.environ.get("POS_CONDITIONAL_BIGRAM", "0"))), + ) if bigram_vocab_size > 0 else None + self.word_start_boost = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) if bigram_vocab_size > 0 else None + self._has_leading_space: Tensor | None = None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + _HessianBlock(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + layer_idx=i, ln_scale=ln_scale) + for i in range(num_layers) + ]) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa 
= True + kv_dim = num_kv_heads * (model_dim // num_heads) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList([nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + def _get_ve(self, layer_idx, input_ids, ve_cache): + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: return None + if 've' not in ve_cache: ve_cache['ve'] = self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_cache['ve'] * self.ve_layer_scales[ve_idx].to(dtype=ve_cache['ve'].dtype) + def forward(self, input_ids, target_ids): + x = self.tok_emb(input_ids) + if self.bigram is not None: + bg = self.bigram(input_ids, self._has_leading_space) + if self.word_start_boost is not None and self._has_leading_space is not None: + ws_mask = self._has_leading_space[input_ids].unsqueeze(-1).to(dtype=bg.dtype) + bg = bg * (1.0 + ws_mask * (self.word_start_boost.to(dtype=bg.dtype) - 1.0)) + x = x + bg + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x; skips = []; ve_cache = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve); skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)); targets = target_ids.reshape(-1) + logits_proj = F.linear(x_flat, self.tok_emb.weight) if self.tie_embeddings else self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + +def collect_hessians(hessian_model, train_loader, args, device, grad_accum_steps, num_batches=256): + hessians = {}; hooks = [] + for name, module in hessian_model.named_modules(): + if isinstance(module, CastedLinear): + param_name = name + ".weight" + cols = module.weight.shape[1] + hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') + def make_hook(pname): + def hook_fn(module, input, output): + x = input[0].detach().float() + if x.ndim == 3: x = x.reshape(-1, x.shape[-1]) + hessians[pname] += (x.T @ x).cpu() + return hook_fn + h = module.register_forward_hook(make_hook(param_name)); hooks.append(h) + hessian_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for _ in range(num_batches): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + hessian_model(x, y) + for h in hooks: h.remove() + for name in hessians: + H = hessians[name]; H /= num_batches + damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6) + H += damp * torch.eye(H.shape[0]); hessians[name] = H + hessian_model.train() + return hessians + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], hessians: dict[str, Tensor] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + 
late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough"; continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float(); meta[name] = "passthrough_ctrl"; continue + if cat in int6_cats and t.ndim >= 1: + cr = 31 + H = hessians.get(name) if hessians else None + if H is not None: + q, s = quantize_int6_gptq(t, hessian=H, clip_range=cr) + else: + q, s = quantize_int6_per_row(t, clip_range=cr) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t; continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Meta-TTT (FOMAML + exp106 extensions) --- + +def _meta_ttt_split(x: Tensor, y: Tensor, mode: str) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """exp106 (A): pick the inner/outer split for meta-TTT. + + mode="batch": cross-sample split along the batch dim. Inner and outer are drawn from + DIFFERENT training sequences, which in fineweb10B means different documents. + This matches deployment-time TTT (adapt on past text, predict upcoming text from + a likely different distributional regime) far better than the legacy same-doc split. + Falls back to "seq" if batch size is <2. + mode="seq" (legacy): first half / second half of the same sequence. High inner/outer + correlation because both halves are from the same document. + """ + b = x.shape[0] + if mode == "batch" and b >= 2: + half = b // 2 + return x[:half], y[:half], x[half:], y[half:] + # Fallback or explicit seq mode + seq_len = x.shape[1] + half = seq_len // 2 + return x[:, :half], y[:, :half], x[:, half:], y[:, half:] + +def meta_ttt_step(base_model: nn.Module, x: Tensor, y: Tensor, + args, grad_scale: float = 1.0) -> Tensor: + """Meta-TTT step with exp106's three changes on top of the exp101 FOMAML baseline: + + (A) Cross-chunk inner/outer split (from exp106): split along batch dim so inner and + outer are different documents. Matches deployment-time TTT statistical regime. + + (B) Delta loss (from exp106): outer loss = post_weight * loss_post + delta_weight * (loss_post - loss_pre). + Explicitly rewards the backbone for developing features where SGD-on-banks improves. + + (D) SAM inner loop (exp107, replaces exp106's MetaSGD which converged to uniform 1.0): + Instead of vanilla SGD, the inner step uses Sharpness-Aware Minimization: + 1. Compute gradient g at current banks (vanilla forward+backward) + 2. 
Perturb banks in the ascent direction: banks_pert = banks + rho * g / ||g|| + 3. Compute gradient g_sam at the perturbed point (second forward+backward) + 4. Use g_sam (not g) for the adaptation step: banks' = banks - lr * g_sam + This finds adapted banks in flatter regions of the loss landscape. If the TTT + ceiling is determined by local curvature, SAM's flatness-seeking may break it. + + Runs uncompiled (forward_with_banks) to avoid recompilation from new bank tensors. + """ + n = base_model.num_layers + freeze_n = args.meta_ttt_freeze_blocks + lr = args.meta_ttt_inner_lr + post_weight = args.meta_ttt_loss_weight + delta_weight = args.meta_ttt_delta_weight + sam_on = args.meta_ttt_sam_enabled + + # (A) Cross-chunk split + x_inner, y_inner, x_outer, y_outer = _meta_ttt_split(x, y, args.meta_ttt_split) + + # --- Inner loop: detached banks as leaves, compute grads on chunk_A --- + qo_in = base_model.qo_bank.detach().clone().requires_grad_(True) + kv_in = base_model.kv_bank.detach().clone().requires_grad_(True) + up_in = base_model.mlp_up_bank.detach().clone().requires_grad_(True) + down_in = base_model.mlp_down_bank.detach().clone().requires_grad_(True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + inner_loss = base_model.forward_with_banks(x_inner, y_inner, qo_in, kv_in, up_in, down_in) + + g_qo, g_kv, g_up, g_down = torch.autograd.grad(inner_loss, [qo_in, kv_in, up_in, down_in]) + + # Detach gradients (first-order FOMAML). + g_qo = g_qo.detach(); g_kv = g_kv.detach() + g_up = g_up.detach(); g_down = g_down.detach() + + # Freeze mask: zero gradients for frozen blocks + if freeze_n > 0: + g_qo = g_qo.clone(); g_kv = g_kv.clone(); g_up = g_up.clone(); g_down = g_down.clone() + with torch.no_grad(): + for bi in range(freeze_n): + g_qo[bi].zero_(); g_qo[n + bi].zero_() + g_kv[bi].zero_(); g_kv[n + bi].zero_() + g_up[bi].zero_() + g_down[bi].zero_() + + # (D) SAM: compute sharpness-aware gradient at the ascent-perturbed point. + # The gradient g from step above is the "vanilla" direction. SAM perturbs banks + # in the direction of steepest ascent (g / ||g||), then recomputes the gradient + # at that perturbed point. The result g_sam points toward flatter minima. 
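+    # Vanilla-SAM arithmetic used below (Foret et al., ICLR 2021):
+    #   eps    = rho * g / ||g||_2      with ||g|| the joint norm over all 4 banks
+    #   g_sam  = grad of the inner loss evaluated at banks + eps
+    #   banks' = banks - lr * g_sam
+    # e.g. rho=0.05 and ||g||=2.0 give a perturbation of radius 0.025, shared
+    # across qo/kv/up/down; adaptive SAM instead scales eps elementwise by |param|.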
+    if sam_on:
+        rho = args.meta_ttt_sam_rho
+
+        # Joint gradient norm across all 4 banks (float32 for precision)
+        grad_norm = (g_qo.float().norm().square() + g_kv.float().norm().square() +
+                     g_up.float().norm().square() + g_down.float().norm().square()
+                     ).sqrt().clamp(min=1e-12)
+
+        # Ascent perturbation: epsilon = rho * g / ||g||
+        with torch.no_grad():
+            if args.meta_ttt_sam_adaptive:
+                # Adaptive SAM: scale elementwise by |param|, so each weight (not just
+                # each bank) gets its own effective perturbation radius
+                eps_qo = rho * (g_qo / grad_norm) * qo_in.abs().clamp(min=1e-7)
+                eps_kv = rho * (g_kv / grad_norm) * kv_in.abs().clamp(min=1e-7)
+                eps_up = rho * (g_up / grad_norm) * up_in.abs().clamp(min=1e-7)
+                eps_dn = rho * (g_down / grad_norm) * down_in.abs().clamp(min=1e-7)
+            else:
+                # Vanilla SAM: uniform perturbation radius
+                scale = rho / grad_norm
+                eps_qo = scale * g_qo
+                eps_kv = scale * g_kv
+                eps_up = scale * g_up
+                eps_dn = scale * g_down
+
+        # Perturbed banks (ascent point — loss should be higher here)
+        qo_pert = (qo_in.detach() + eps_qo).requires_grad_(True)
+        kv_pert = (kv_in.detach() + eps_kv).requires_grad_(True)
+        up_pert = (up_in.detach() + eps_up).requires_grad_(True)
+        down_pert = (down_in.detach() + eps_dn).requires_grad_(True)
+
+        # Second forward+backward at the perturbed point → SAM gradient
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+            loss_pert = base_model.forward_with_banks(
+                x_inner, y_inner, qo_pert, kv_pert, up_pert, down_pert)
+
+        g_sam_qo, g_sam_kv, g_sam_up, g_sam_down = torch.autograd.grad(
+            loss_pert, [qo_pert, kv_pert, up_pert, down_pert])
+
+        # Detach SAM gradients (first-order FOMAML)
+        g_sam_qo = g_sam_qo.detach(); g_sam_kv = g_sam_kv.detach()
+        g_sam_up = g_sam_up.detach(); g_sam_down = g_sam_down.detach()
+
+        # Re-apply freeze mask (the SAM backward doesn't know about frozen blocks)
+        if freeze_n > 0:
+            g_sam_qo = g_sam_qo.clone(); g_sam_kv = g_sam_kv.clone()
+            g_sam_up = g_sam_up.clone(); g_sam_down = g_sam_down.clone()
+            with torch.no_grad():
+                for bi in range(freeze_n):
+                    g_sam_qo[bi].zero_(); g_sam_qo[n + bi].zero_()
+                    g_sam_kv[bi].zero_(); g_sam_kv[n + bi].zero_()
+                    g_sam_up[bi].zero_()
+                    g_sam_down[bi].zero_()
+
+        # Use the SAM gradient for the adaptation step
+        use_g_qo, use_g_kv = g_sam_qo, g_sam_kv
+        use_g_up, use_g_down = g_sam_up, g_sam_down
+    else:
+        # Vanilla SGD (exp101 fallback when SAM is disabled)
+        use_g_qo, use_g_kv = g_qo, g_kv
+        use_g_up, use_g_down = g_up, g_down
+
+    # Build adapted banks as leaves with requires_grad=True for the outer backward.
+    qo_bank_det = base_model.qo_bank.detach()
+    kv_bank_det = base_model.kv_bank.detach()
+    up_bank_det = base_model.mlp_up_bank.detach()
+    down_bank_det = base_model.mlp_down_bank.detach()
+    with torch.no_grad():
+        qo_upd = (qo_bank_det - lr * use_g_qo).requires_grad_(True)
+        kv_upd = (kv_bank_det - lr * use_g_kv).requires_grad_(True)
+        up_upd = (up_bank_det - lr * use_g_up).requires_grad_(True)
+        down_upd = (down_bank_det - lr * use_g_down).requires_grad_(True)
+
+    # --- Outer loop (unchanged from exp106) ---
+    # (B) loss_pre: forward on the OUTER chunk with the ORIGINAL banks (LIVE, so grads
+    # flow to backbone non-bank params AND to the banks directly).
+    loss_pre: Tensor | None = None
+    if delta_weight != 0.0:
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+            loss_pre = base_model.forward_with_banks(
+                x_outer, y_outer,
+                base_model.qo_bank, base_model.kv_bank,
+                base_model.mlp_up_bank, base_model.mlp_down_bank)
+
+    # loss_post: forward with the adapted banks. Non-bank params LIVE → grads flow directly.
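+    # The combination below is the docstring's (B), regrouped:
+    #   post_weight * loss_post + delta_weight * (loss_post - loss_pre)
+    #     == (post_weight + delta_weight) * loss_post - delta_weight * loss_pre
+    # so only two outer forwards are ever needed, and the loss_pre forward is skipped
+    # entirely when delta_weight == 0.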
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_post = base_model.forward_with_banks(x_outer, y_outer, qo_upd, kv_upd, up_upd, down_upd) + + # Outer loss combines post + delta. + if loss_pre is not None: + outer_loss = (post_weight + delta_weight) * loss_post - delta_weight * loss_pre + else: + outer_loss = post_weight * loss_post + + scaled = outer_loss * grad_scale + scaled.backward() + + # FOMAML: copy adapted-point gradients onto original bank params. + with torch.no_grad(): + for bank, upd in [(base_model.qo_bank, qo_upd), + (base_model.kv_bank, kv_upd), + (base_model.mlp_up_bank, up_upd), + (base_model.mlp_down_bank, down_upd)]: + if upd.grad is not None: + if bank.grad is None: + bank.grad = upd.grad.to(bank.dtype).clone() + else: + bank.grad.add_(upd.grad.to(bank.dtype)) + + return loss_post.detach() + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + total_accum = int(os.environ.get("GRAD_ACCUM_TOTAL", "8")) + if total_accum % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide GRAD_ACCUM_TOTAL={total_accum}") + grad_accum_steps = total_accum // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False); enable_flash_sdp(True); enable_mem_efficient_sdp(False); enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True, flush: bool = False) -> None: + if not master_process: return + if console: print(msg, flush=flush) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f, flush=True) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}") + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else 
args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + dtg=args.dtg_enabled, ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + ).to(device).bfloat16() + base_model._has_leading_space = has_leading_space_lut + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + # Freeze dead skip connections: B0→B9 (row 4) and B1→B8 (row 3) are near-zero + with torch.no_grad(): + if base_model.skip_weights.shape[0] >= 5: + base_model.skip_weights.data[3].zero_() # B1→B8 + base_model.skip_weights.data[4].zero_() # B0→B9 + _skip_freeze_mask = torch.ones_like(base_model.skip_weights.data) + if base_model.skip_weights.shape[0] >= 5: + _skip_freeze_mask[3].zero_() + _skip_freeze_mask[4].zero_() + base_model.skip_weights.register_hook(lambda grad: grad * _skip_freeze_mask) + matrix_params = [base_model.qo_bank, base_model.kv_bank, base_model.mlp_up_bank, base_model.mlp_down_bank] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: scalar_params.append(base_model.bigram.scale) + if base_model.word_start_boost is not None: scalar_params.append(base_model.word_start_boost) + if base_model.ve_shared is not None: + if base_model.ve_shared.proj is not None: scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: scalar_params.append(s) + # exp107: MetaSGD removed (converged to uniform 1.0, +8.6 GB overhead). SAM has no + # learned params — rho is a fixed hyperparameter configured via META_TTT_SAM_RHO. 
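+    # Parameter routing summary for the optimizers assembled here: the four bank
+    # tensors go to Muon; token/bigram/VE embedding tables to AdamW at token_lr;
+    # ndim < 2 tensors and control tensors to the scalar AdamW group; an untied
+    # lm_head to its own Adam. A minimal sketch of the core split rule, assuming a
+    # generic nn.Module `m` (illustrative only, not invoked here):
+    #   matrices = [p for _, p in m.named_parameters() if p.ndim >= 2]
+    #   scalars  = [p for _, p in m.named_parameters() if p.ndim < 2]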
+ token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + optimizer_tok = torch.optim.AdamW(tok_params, betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr + _scalar_groups: list[dict] = [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}] + optimizer_scalar = torch.optim.AdamW(_scalar_groups, + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + _warmdown_triggered_ms: float | None = None + _loss_ema: float | None = None + _loss_ema_prev: float | None = None + def lr_mul(step: int, elapsed_ms: float) -> float: + nonlocal _warmdown_triggered_ms + if args.warmdown_iters <= 0: return 1.0 + # Guard: don't compute warmdown in early steps where step_ms estimate is unreliable + if step < args.warmup_steps + 5: return 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_duration_ms = args.warmdown_iters * step_ms + if _warmdown_triggered_ms is not None: + warmdown_start_ms = _warmdown_triggered_ms + elif max_wallclock_ms is not None: + warmdown_start_ms = max(max_wallclock_ms - warmdown_duration_ms, 0.0) + else: + warmdown_start_ms = max(args.iterations - args.warmdown_iters, 0) * step_ms + if elapsed_ms >= warmdown_start_ms: + total_warmdown_ms = (max_wallclock_ms - warmdown_start_ms) if max_wallclock_ms else warmdown_duration_ms + progress_ms = elapsed_ms - 
warmdown_start_ms + progress = min(progress_ms / max(total_warmdown_ms, 1e-9), 1.0) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + if distributed: + for p in base_model.parameters(): + if p.grad is not None: dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_pairs: list[tuple[Tensor, Tensor]] | None = None + ema_state: dict[str, Tensor] | None = None + ema_update_every = int(os.environ.get("EMA_UPDATE_EVERY", "10")) + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_pairs = [(ema_state[name], param) for name, param in base_model.named_parameters()] + for name, buf in base_model.named_buffers(): + if name in ema_state: ema_pairs.append((ema_state[name], buf)) + ema_decay_eff = args.ema_decay ** ema_update_every + log0(f"ema:initialized decay={args.ema_decay} update_every={ema_update_every} decay_eff={ema_decay_eff:.6f}") + ema_stream = torch.cuda.Stream() if ema_pairs is not None else None + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} 
scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + next_xy = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + for micro_step in range(grad_accum_steps): + x, y = next_xy + if micro_step < grad_accum_steps - 1: + next_xy = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + # Meta-TTT: FOMAML inner/outer loop on the last micro-batch + # Disabled during warmdown (scale < 0.5) — model should converge, not explore + if args.meta_ttt_enabled and step % args.meta_ttt_every == 0 and scale > 0.5: + meta_outer = meta_ttt_step(base_model, x, y, args, grad_scale=grad_scale) + if args.adaptive_warmdown and _warmdown_triggered_ms is None and step >= args.adaptive_warmdown_min_steps and step % 100 == 0: + tl = train_loss.item() + if _loss_ema is None: + _loss_ema = tl; _loss_ema_prev = tl + else: + _loss_ema = args.adaptive_warmdown_ema * _loss_ema + (1.0 - args.adaptive_warmdown_ema) * tl + improvement = (_loss_ema_prev - _loss_ema) if _loss_ema_prev is not None else 1.0 + if improvement < args.adaptive_warmdown_threshold: + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + _warmdown_triggered_ms = elapsed_now + log0(f"adaptive_warmdown:triggered step:{step} loss_ema:{_loss_ema:.6f} improvement:{improvement:.6f}") + _loss_ema_prev = _loss_ema + if step < args.muon_momentum_warmup_steps: + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not None: dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step(); optimizer_scalar.step() + if optimizer_head is not None: optimizer_head.step() + optimizer_muon.step() + zero_grad_all() + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if ema_pairs is not None and step % ema_update_every == 0: + ema_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(ema_stream), torch.no_grad(): + for ema_t, param_t in ema_pairs: + ema_t.mul_(ema_decay_eff).add_(param_t.detach().float(), alpha=1.0 - ema_decay_eff) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1; log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + 
lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + if should_log_train: + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms") + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: stop_after_step = step + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB") + if ema_stream is not None: ema_stream.synchronize() + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize(); t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val(args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms") + full_state_dict = base_model.state_dict() + # Drop MTP heads from the exported checkpoint (not used at inference/TTT time). 
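+    # Export pipeline from here: save the raw fp export → unbank per-layer weights →
+    # collect GPTQ Hessians on self-generated text → mixed int6/int8 quantization →
+    # lzma → selective ±1 pruning down to the size target. Schematically (sizes only;
+    # names as in the code below):
+    #   submission_bytes ≈ len(lzma.compress(torch_save_bytes(quant_result))) + len(code_utf8)
+    # and the binary search below drives submission_bytes under TARGET_MB.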
+ def _drop_from_export(k: str) -> bool: + return "mtp_heads" in k + export_sd = {k: v for k, v in full_state_dict.items() if not _drop_from_export(k)} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + exp_name = args.run_id.rsplit("_seed", 1)[0] if "_seed" in args.run_id else args.run_id + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + log0(f"gptq:building non-banked model for Hessian collection...") + hessian_model = _HessianGPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, logit_softcap=args.logit_softcap, + rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + hessian_model._has_leading_space = has_leading_space_lut + for m in hessian_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(hessian_model) + hessian_model.load_state_dict({k: v.to(device) for k, v in unbanked_sd.items() if k in hessian_model.state_dict()}, strict=False) + log0("gptq:generating autoregressive calibration data (64 seqs x 2048 tokens, temp=0.8)...") + base_model.load_state_dict(export_sd, strict=False) + t_gen = time.perf_counter() + ar_tokens = generate_autoregressive_calib(base_model, device, num_seqs=64, seq_len=args.train_seq_len, + vocab_size=args.vocab_size, temperature=0.8, batch_size=8, seed=args.seed) + log0(f"gptq:generated {len(ar_tokens)} sequences in {time.perf_counter()-t_gen:.1f}s") + log0("gptq:collecting hessians from autoregressive data...") + hessians = collect_hessians_from_tokens(hessian_model, ar_tokens, device) + log0(f"gptq:collected hessians for {len(hessians)} layers (AR self-gen)") + del ar_tokens; del hessian_model; torch.cuda.empty_cache() + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}, hessians=hessians) + target_mb = float(os.environ.get("TARGET_MB", "15.9")) + code_bytes_est = len(code.encode("utf-8")) + ones_info = [] + for name, info in quant_meta.items(): + if not (isinstance(info, dict) and info.get("type") == "int6"): continue + qk, sk = name + ".q", name + ".scale" + if qk not in quant_result or sk not in quant_result: continue + q, s = quant_result[qk], quant_result[sk] + if s.ndim > 0: + ones_mask = (q.abs() == 1) + if ones_mask.any(): + row_idx = torch.arange(q.shape[0]).unsqueeze(1).expand_as(q)[ones_mask] + flat_idx = torch.arange(q.numel()).reshape(q.shape)[ones_mask] + errors = s.float()[row_idx].pow(2) + for fi, err in zip(flat_idx.tolist(), errors.tolist()): + ones_info.append((qk, fi, err)) + if ones_info: + ones_info.sort(key=lambda x: x[2]) + def _try_prune(n): + tmp = {k: v.clone() for k, v in quant_result.items()} + for i in range(min(n, len(ones_info))): + tmp[ones_info[i][0]].view(-1)[ones_info[i][1]] = 0 + buf = io.BytesIO(); torch.save({"w": 
tmp, "m": quant_meta}, buf)
+            return len(lzma.compress(buf.getvalue(), preset=9)) + code_bytes_est, tmp
+        no_sz, _ = _try_prune(0)
+        target_bytes = int(target_mb * 1024 * 1024)
+        log0(f"selective_prune: {len(ones_info)} +/-1 candidates, unpruned={no_sz/(1024*1024):.2f}MB target={target_mb}MB")
+        if no_sz <= target_bytes:
+            log0("selective_prune: already fits, no pruning needed")
+        else:
+            full_sz, _ = _try_prune(len(ones_info))
+            log0(f"selective_prune: full +/-1 prune={full_sz/(1024*1024):.2f}MB")
+            if full_sz > target_bytes:
+                log0("selective_prune: even full prune not enough, applying all")
+                _, quant_result = _try_prune(len(ones_info))
+            else:
+                lo, hi = 0, len(ones_info)
+                while lo < hi:
+                    mid = (lo + hi) // 2
+                    sz, _ = _try_prune(mid)
+                    if sz <= target_bytes: hi = mid
+                    else: lo = mid + 1
+                log0(f"selective_prune: pruning {lo}/{len(ones_info)} +/-1 values ({100*lo/len(ones_info):.1f}%) to fit {target_mb}MB")
+                _, quant_result = _try_prune(lo)
+    quant_buf = io.BytesIO()
+    torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = lzma.compress(quant_raw, preset=9)
+    if master_process:
+        with open("final_model.int6.ptz", "wb") as f: f.write(quant_blob)
+        quant_file_bytes = len(quant_blob)
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes")
+        log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes")
+    if distributed: dist.barrier()
+    with open("final_model.int6.ptz", "rb") as f: quant_blob_disk = f.read()
+    quant_state = torch.load(io.BytesIO(lzma.decompress(quant_blob_disk)), map_location="cpu")
+    deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd)
+    deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu)
+    # exp106 note: MetaSGD's meta_sgd_* lr scalars were training-only (used solely in
+    # meta_ttt_step's inner-SGD update) and never influenced the eval forward, but
+    # eval_model had them as nn.Parameters from GPT.__init__, so the strict=True load
+    # below required re-injecting them.
+    # exp107: MetaSGD removed — no meta_sgd_* keys to re-inject.
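+    # deq_state round-trips through the on-disk artifact (final_model.int6.ptz), so
+    # the eval below measures the real submission path. Recap of the per-row int6
+    # reconstruction it applies (mirrors dequantize_mixed_int6; 2-D shapes illustrative):
+    #   q, s  = quant_state["w"][name + ".q"], quant_state["w"][name + ".scale"]
+    #   w_hat = (q.float() * s.float().view(q.shape[0], 1)).to(orig_dtype)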
+ eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + ).to(device).bfloat16() + eval_model._has_leading_space = has_leading_space_lut + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize(); t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val(args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len) + torch.cuda.synchronize() + log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms") + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + # TTT-only evaluation (no standalone sliding window — meta-TTT model is designed for TTT) + if args.ttt_enabled and args.eval_stride > 0: + log0(f"\n{'='*60}", flush=True) + log0(f"STARTING TTT (Test-Time Training)", flush=True) + log0(f"{'='*60}", flush=True) + eval_model.load_state_dict(deq_state, strict=True) + ttt_loss, ttt_bpb = eval_val_sliding_ttt(args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0) + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f}", flush=True) + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}", flush=True) + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main()
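+
+# Launch sketch (assumptions: torchrun sets the RANK, WORLD_SIZE, and LOCAL_RANK env
+# vars that main() reads; GRAD_ACCUM_TOTAL must be divisible by WORLD_SIZE; the
+# script filename is a placeholder):
+#   GRAD_ACCUM_TOTAL=8 torchrun --standalone --nproc_per_node=2 train_exp107.py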