diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/README.md b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/README.md new file mode 100644 index 0000000000..27a292b9a8 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/README.md @@ -0,0 +1,320 @@ +# Non-Record: 11L NativeFlowMatcher + Legal Score-First TTT + +**val_bpb: 1.11991** (seed=42, sliding window, stride=64, int6/int5 quantized, legal TTT) +**3-seed mean sliding BPB (no TTT): 1.12252** ± 0.00151 | **3-seed mean legal TTT: 1.11928** ± 0.00146 + +Non-record submission exploring **NativeFlowMatcher (NFM)** — a 393K-parameter OT-CFM (Optimal Transport Conditional Flow Matching) velocity network that applies gated hidden-state correction, jointly trained with the AR objective. Combined with legal score-first TTT for additional compression. + +> **Update (2026-04-01):** All ablation studies complete. Three-seed legal TTT evals finished. NFM does not improve BPB vs matched base — see Ablation Studies section for full 2×2 matrix, loss weight sweep, and hidden dim sweep. + +## Architecture + +| Component | Value | +|-----------|-------| +| Layers | 11 | +| Model dim | 512 | +| Attention heads | 8 (4 KV heads, GQA) | +| MLP expansion | 3× (1536 hidden) | +| Vocab size | 1024 (SentencePiece BPE) | +| Sequence length | 2048 (training), 1024 (eval window) | +| Total params | 27,530,952 | +| NFM params | 393,729 | +| Tied embeddings | Yes | +| BigramHash | vocab=4096, dim=128 | +| Positional encoding | Partial RoPE (16 dims, base 10000) | +| XSA | All 11 layers (XSA_LAST_N=11) | +| Value residual | Yes | +| Gated attention | Yes | +| Activation | LeakyReLU(0.5)² | +| Logit softcap | 30.0 | +| LN scale | 1 | +| QK gain init | 1.5 | +| EMA | decay=0.997, applied at end | + +### NativeFlowMatcher (NFM) + +NFM is a conditional flow matching module inserted after the final LayerNorm in the transformer stack, before the language model head. It learns a velocity field over hidden states using OT-CFM: + +- **Time embedding:** Sinusoidal positional encoding of scalar t → projected to 256-dim via Linear(512,256) + GELU +- **Velocity network:** Linear(512,256) + time-conditioning (additive) + GELU → Linear(256,512, no bias) +- **Gate:** Scalar parameter initialized at −5.0 (sigmoid ≈ 0.007), learned during training +- **Training loss:** MSE between predicted velocity v(x_t, t) and OT target velocity (x − z), where x_t = (1−t)·z + t·x with z ~ N(0,I). Weighted by `NATIVE_FLOW_LOSS_WEIGHT=0.1` added to AR cross-entropy loss. +- **Inference:** Single Euler step at t=1 on clean input x, gated: `x_out = x + sigmoid(gate) · v(x, t=1)` + +The NFM velocity network has zero-initialized output weights and a near-zero initial gate, ensuring the correction starts negligible and grows only as the velocity field learns useful structure. + +## Results + +### Three-Seed Reproducibility (Training-Time Eval) + +All three seeds trained identically: 7,000 steps, 1×A100 PCIe 40GB, same architecture and optimizer config. 
+ +| Seed | SLURM Job | Training val_bpb | Roundtrip BPB | Sliding (no TTT) BPB | Legal TTT BPB | Artifact Bytes | +|------|-----------|-----------------|---------------|----------------------|---------------|----------------| +| 42 | 55342820 | 1.1380 | 1.14679034 | **1.12311579** | **1.11990650** | 15,745,776 | +| 1337 | 55398556 | 1.1385 | 1.14729126 | **1.12366996** | **1.12032079** | 15,736,933 | +| 2025 | 55398557 | 1.1359 | 1.14444585 | **1.12077485** | **1.11761299** | 15,745,950 | +| **Mean** | | **1.1375** | **1.14617582** | **1.12252020** | **1.11928009** | — | +| **Std** | | **0.0014** | **0.00157** | **0.00151** | **0.00146** | — | + +> Legal TTT evaluation complete for all seeds. + +### Primary (Seed=42, This Submission) + +| Evaluation | val_loss | val_bpb | +|------------|----------|---------| +| Roundtrip (dequantized) | 1.93630746 | 1.14679034 | +| Sliding window (stride=64), no TTT | 1.89632895 | 1.12311579 | +| **Sliding window (stride=64), legal TTT** | **1.89091021** | **1.11990650** | + +Legal TTT improvement: **−0.00321 BPB** (from 1.12312 → 1.11991) + +### Supplementary Comparison + +These are reference results from related configurations, included for context. All use the same base architecture (PR #940 stack) and evaluation protocol. + +| Config | Steps | Params | No-TTT Sliding BPB | Legal TTT Sliding BPB | +|--------|-------|--------|---------------------|----------------------| +| Base (no refiners) 20k | 20,000 | ~27.1M | 1.10050 | 1.09292 | +| FlowRefiner 20k | 20,000 | ~27.2M | 1.10002 | 1.09279 | +| **NFM 7k (this submission)** | **7,000** | **27.5M** | **1.12312** | **1.11991** | +| E2E TTT + FlowRefiner 7k | 7,000 | 28.3M | — | 1.12418 | + +> **Update:** The E2E TTT + FlowRefiner legal TTT eval (SLURM 55398555) completed with val_bpb=1.12418. Previous submission had partial data (truncated at chunk 1271/1893). + +> **Important context:** The 20k-step results (base, flow) use a longer training schedule (20,000 steps vs 7,000). Direct BPB comparison between 7k and 20k is not meaningful for architecture evaluation. The NFM contribution should be assessed relative to the base architecture at matched step count, but no 7k-step base-only run exists in this evaluation set. The training-time val_bpb at step 7000 was 1.1380 (pre-quantization, non-sliding, no TTT). + +## Quantization + +| Property | Value | +|----------|-------| +| Base scheme | Per-row int8 (2D weights) / per-tensor int8 (1D) | +| MLP layers 0–4, 7–10 | int6 (GPTQ-lite) | +| MLP layers 5–6 | int5 (auto-downgrade fallback to fit 16MB) | +| Compression | zstd level 16 | +| Quantized model | 15,630,744 bytes | +| Code (`train_gpt.py`) | 115,032 bytes | +| **Total artifact** | **15,745,776 bytes** (headroom: 254,224 bytes) | + +The auto-downgrade mechanism progressively applies int5 quantization to middle MLP layers (starting from layer 5 outward) until the compressed artifact fits within the 16MB budget. 
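The `quant_try` / `quant_fallback` lines in the supplementary eval logs show the search order: each attempt tries a few zstd levels, and each fallback step downgrades one more middle MLP layer to int5. A minimal sketch of that loop, with a hypothetical `quantize_and_compress()` standing in for the real export path in `train_gpt.py`:

```python
BUDGET_BYTES = 16_000_000                      # total submission budget
CODE_BYTES = 115_032                           # train_gpt.py
MODEL_LIMIT = BUDGET_BYTES - CODE_BYTES        # 15,884,968 bytes left for the model artifact

def middle_out_layers(num_layers: int = 11, start: int = 5):
    """Yield MLP layer indices fanning outward from the middle: 5, 6, 4, 7, 3, ..."""
    yield start
    for off in range(1, num_layers):
        for li in (start + off, start - off):
            if 0 <= li < num_layers:
                yield li

def fit_under_budget(state_dict):
    int5_layers: list[int] = []
    for layer in [None, *middle_out_layers()]:
        if layer is not None:
            int5_layers.append(layer)          # downgrade one more MLP layer to int5
        # Each attempt tries several zstd levels, as in the logs (16, 1, 17, 2).
        sizes = [quantize_and_compress(state_dict, int5_layers, zstd_level=lvl)  # stand-in helper
                 for lvl in (16, 1, 17, 2)]
        if min(sizes) <= MODEL_LIMIT:
            return int5_layers                 # e.g. [5, 6] for the submitted seed-42 model
    raise RuntimeError("artifact does not fit in the 16MB budget")
```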
+ +## Training + +| Property | Value | +|----------|-------| +| Hardware | 1× A100 PCIe 40GB | +| Steps | 7,000 | +| Wallclock | 13,879 seconds (3.86 hours) | +| Step average | 1,982.77 ms | +| Training tokens | ~5.51B (7000 × 786432) | +| Sequence length | 2048 | +| Optimizer | Muon (matrix) + Adam (scalars/embeddings) | +| Matrix LR | 0.025 | +| Scalar LR | 0.025 | +| Muon weight decay | 0.04 | +| Adam weight decay | 0.04 | +| Gradient clip | 0.3 | +| Warmup | 20 steps | +| Warmdown | 2,800 steps | +| Seed | 42 | +| SLURM job | 55342820 | +| Peak GPU memory | 25,832 MiB | + +### Training Trajectory (All Seeds) + +| Step | Seed 42 | Seed 1337 | Seed 2025 | +|------|---------|-----------|-----------| +| 0 | 4.1055 | 4.1175 | 4.1065 | +| 500 | 1.3813 | 1.3849 | 1.3859 | +| 1000 | 1.3058 | 1.3100 | 1.3070 | +| 2000 | 1.2499 | 1.2497 | 1.2490 | +| 3000 | 1.2283 | 1.2285 | 1.2269 | +| 4000 | 1.2199 | 1.2205 | 1.2190 | +| 5000 | 1.1975 | 1.1983 | 1.1958 | +| 6000 | 1.1707 | 1.1710 | 1.1686 | +| 6500 | 1.1527 | 1.1532 | 1.1508 | +| 7000 | 1.1380 | 1.1385 | 1.1359 | + +## Legal TTT Configuration + +Score-first test-time training that complies with the rule that training may only occur on tokens that have already been scored (no future information leakage). + +| Parameter | Value | +|-----------|-------| +| Optimizer | SGD with momentum=0.9 | +| Learning rate | 0.002 | +| Epochs per chunk | 10 | +| Chunk size | 32,768 tokens | +| Frozen blocks | 2 (first 2 transformer layers frozen) | +| Gradient clip | 1.0 | +| Total chunks | 1,893 | +| TTT eval time | ~7,190 seconds (~2.0 hours) | +| SLURM job | 55375245 | + +## Provenance + +All artifacts trace back to verifiable SLURM jobs: + +### Seed 42 (Primary) +1. **Training:** SLURM job 55342820 → `runs/nflow_55342820/models/final_model_pr940_nflow_55342820.pt` +2. **Eval (no TTT):** SLURM job 55375246 → sliding BPB = 1.12312, artifact = 15,745,776 bytes +3. **Eval (legal TTT):** SLURM job 55375245 → sliding BPB = 1.11991, artifact = 15,745,776 bytes +4. **Submitted model:** `final_model.int6.ptz` is the quantized+compressed artifact from eval job 55375245 + +### Seed 1337 (Reproducibility) +1. **Training:** SLURM job 55398556 → `runs/nflow_s1337_55398556/models/final_model_pr940_nflow_s1337_55398556.pt` +2. **Training-time sliding BPB (no TTT):** 1.12367, artifact = 15,736,933 bytes +3. **Eval (legal TTT):** SLURM job 55411651 → sliding BPB = 1.12032 +4. **Eval (no TTT):** SLURM job 55411652 → sliding BPB = 1.12367 + +### Seed 2025 (Reproducibility) +1. **Training:** SLURM job 55398557 → `runs/nflow_s2025_55398557/models/final_model_pr940_nflow_s2025_55398557.pt` +2. **Training-time sliding BPB (no TTT):** 1.12077, artifact = 15,745,950 bytes +3. **Eval (legal TTT):** SLURM job 55411653 → sliding BPB = 1.11761 +4. **Eval (no TTT):** SLURM job 55411654 → sliding BPB = 1.12077 + +### Supplementary +5. **E2E TTT + FlowRefiner eval (complete):** SLURM job 55398555 → legal TTT BPB = 1.12418 + +The training SLURM scripts and evaluation SLURM scripts for all seeds are included in `supplementary/` for full reproducibility. + +> **Code size note:** The submitted `train_gpt.py` (115,032 bytes) reflects the version used at evaluation time. The training log (`train.log`) reports a code size of 104,738 bytes, reflecting the version at training time. The code evolved between training and evaluation but the model checkpoint is unchanged. Supplementary SLURM scripts reference the working filename `train_gpt_pr940.py`; rename to `train_gpt.py` for reproduction. 
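For reference, a minimal sketch of the score-first protocol summarized in the Legal TTT Configuration table above. The `score_chunk()` and `make_ttt_batches()` helpers and the `blocks` attribute are illustrative stand-ins, not the actual `train_gpt.py` API (see the `ttt_sliding:*` lines in the eval logs for the real loop):

```python
import torch

def legal_ttt_eval(model, chunks, lr=0.002, momentum=0.9, epochs=10,
                   freeze_blocks=2, grad_clip=1.0):
    """Score-first TTT: every chunk is scored before the model is updated on it."""
    # Freeze the first transformer blocks; only the remaining parameters adapt.
    for i, block in enumerate(model.blocks):          # `blocks` is illustrative
        if i < freeze_blocks:
            block.requires_grad_(False)
    opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad],
                          lr=lr, momentum=momentum)
    total_bits, total_bytes = 0.0, 0.0
    for chunk in chunks:                               # 1,893 chunks of 32,768 tokens
        # 1) Score first (stride-64 sliding window, inside score_chunk), so no
        #    token is ever trained on before it contributes to the reported BPB.
        model.eval()
        bits, nbytes = score_chunk(model, chunk)       # hypothetical helper
        total_bits, total_bytes = total_bits + bits, total_bytes + nbytes
        # 2) Then adapt on the already-scored tokens only.
        model.train()
        for _ in range(epochs):
            for x, y in make_ttt_batches(chunk):       # hypothetical helper
                loss = model(x, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                opt.step()
                opt.zero_grad()
    return total_bits / total_bytes                    # bits per byte
```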
+ +## Ablation Studies + +Ablation studies isolating the NFM contribution and exploring hyperparameter sensitivity. All runs use seed=42, 7,000 steps, identical architecture and optimizer settings. + +### Three-Seed Reproducibility + +Training completed for all three seeds. Sliding window (no TTT) results from training-time eval: + +| Seed | SLURM Job | Training val_bpb | Sliding BPB (no TTT) | +|------|-----------|-----------------|---------------------| +| 42 | 55342820 | 1.1380 | 1.12312 | +| 1337 | 55398556 | 1.1385 | 1.12367 | +| 2025 | 55398557 | 1.1359 | 1.12077 | +| **Mean ± Std** | | **1.1375 ± 0.0014** | **1.12252 ± 0.00151** | + +Legal TTT evaluation jobs submitted: 55411651 (s1337, **complete: 1.12032**), 55411653 (s2025, **complete: 1.11761**). + +### 2×2 Matrix: NFM × TTT + +Isolates the NFM and legal-TTT contributions independently. All runs use seed=42. + +| Configuration | Params | No TTT (BPB) | Legal TTT (BPB) | Δ (TTT effect) | +|---------------|--------|--------------|-----------------|------------------| +| Base (no NFM) | 27,137,223 | 1.12106 | 1.11861 | −0.00245 | +| NFM (hd=256, lw=0.1) | 27,530,952 | 1.12312 | 1.11991 | −0.00321 | +| **Δ (NFM effect)** | **+393,729** | **+0.00206** | **+0.00130** | — | + +**NFM hurts by +0.00206 BPB (no TTT) or +0.00130 BPB (with TTT).** The extra 393K parameters do not improve compression. Base ablation: SLURM 55398693 (train), 55398694 (eval no-TTT), 55398695 (eval TTT). + +### Loss Weight Sweep (hidden_dim=256) + +Explores the balance between NFM auxiliary loss and AR cross-entropy loss. + +| loss_weight | No TTT (BPB) | Δ vs base | +|-------------|--------------|----------| +| 0.01 | 1.12344 | +0.00238 | +| 0.05 | 1.12294 | +0.00188 | +| **0.10 (default)** | **1.12312** | **+0.00206** | +| 0.20 | 1.12368 | +0.00262 | + +Best loss weight is 0.05, but still +0.00188 BPB worse than base (1.12106). + +### Hidden Dim Sweep (loss_weight=0.1) + +Explores the capacity of the NFM velocity network. + +| hidden_dim | NFM Params | Total Params | No TTT (BPB) | Δ vs base | +|------------|------------|--------------|--------------|----------| +| 128 | ~197K | 27,334,088 | 1.12228 | +0.00122 | +| **256 (default)** | **393,729** | **27,530,952** | **1.12312** | **+0.00206** | +| 512 | ~787K | 27,924,680 | 1.12219 | +0.00113 | + +Best hidden dim is 512, but still +0.00113 BPB worse than base (1.12106). Increasing NFM capacity does not help. + +> **Conclusion:** NFM consistently hurts across all configurations tested. The auxiliary parameters are better allocated to the main AR model. + +## Limitations & Conclusions + +1. **NFM does not improve val_bpb.** Across all configurations tested (3 loss weights × 3 hidden dims), NFM consistently hurts by +0.001 to +0.003 BPB vs the matched base. The auxiliary parameters are better spent on the main AR model. + +2. **Three-seed reproducibility achieved:** No-TTT mean = 1.12252 ± 0.00151, legal TTT mean = 1.11928 ± 0.00146. + +3. **Non-competitive BPB:** The best result (1.11991) is above the current leaderboard SOTA. This submission documents the NFM negative result and ablation methodology. + +4. **TTT interaction:** NFM shows slightly larger TTT gains (−0.00321) than base (−0.00245), but the absolute score with TTT is still worse than base+TTT (1.11991 vs 1.11861). + +5. **E2E TTT + FlowRefiner eval completed:** SLURM job 55398555 completed with legal TTT BPB = 1.12418. 
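As a sanity check on the NFM parameter counts quoted in the hidden-dim sweep above, the totals follow directly from the `NativeFlowMatcher` layer shapes (time projection `Linear(512, h)` + bias, velocity input `Linear(512, h)` + bias, bias-free output `Linear(h, 512)`, and a scalar gate):

```python
MODEL_DIM = 512

def nfm_param_count(hidden_dim: int) -> int:
    """Parameter count of NativeFlowMatcher at a given velocity-net width."""
    time_proj = MODEL_DIM * hidden_dim + hidden_dim   # Linear(512, h) + bias, then GELU
    v_in      = MODEL_DIM * hidden_dim + hidden_dim   # Linear(512, h) + bias
    v_out     = hidden_dim * MODEL_DIM                # Linear(h, 512), no bias
    gate      = 1                                     # scalar gate logit
    return time_proj + v_in + v_out + gate

for h in (128, 256, 512):
    print(h, nfm_param_count(h))
# 128 -> 196,865 (~197K), 256 -> 393,729, 512 -> 787,457 (~787K)
```

The default `hidden_dim=256` gives 393,729 parameters, matching the NFM row in the Architecture table.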
+ +## Reproduction + +```bash +# Training (single GPU, ~4 hours) +# See supplementary/slurm_pr940_nflow_7k.sh for full env vars +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 +export NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 +export ITERATIONS=7000 SEED=42 +torchrun --standalone --nproc_per_node=1 train_gpt.py + +# Evaluation with legal TTT (~2 hours) +# See supplementary/slurm_eval_nflow7k_legal_ttt.sh for full env vars +export EVAL_ONLY=/path/to/final_model.pt +export TTT_ENABLED=1 LEGAL_TTT=1 +export TTT_LR=0.002 TTT_EPOCHS=10 TTT_FREEZE_BLOCKS=2 +export TTT_CHUNK_TOKENS=32768 TTT_OPTIMIZER=sgd TTT_MOMENTUM=0.9 +export EVAL_STRIDE=64 +torchrun --standalone --nproc_per_node=1 train_gpt.py +``` + +## Credits + +This submission builds on the work of many contributors to the parameter golf contest: + +| Component | Source | Author(s) | +|-----------|--------|-----------| +| Base AR architecture | PR #549 | @abaybektursun | +| Muon optimizer | Baseline | Contest organizers | +| BigramHash, SmearGate | PR #65 | @aquariouserworkman | +| XSA (Cross-Sequence Attention) | PR #187, #265 | @Idan3011, @unnir | +| U-Net skip connections | PR #65, #69 | @aquariouserworkman | +| SWA (Stochastic Weight Averaging) | PR #69 | @aquariouserworkman | +| Mixed int6/int8 quantization | PR #76 | Contest community | +| Sliding window evaluation | PR #50 | @mattqlf | +| Legal score-first TTT | PR #77 | @samacqua | +| VE, Partial RoPE, LN Scale | PR #315, #374 | @jfprincz, @unnir | +| LeakyReLU² activation | Baseline / PR #549 | @abaybektursun | +| EMA | PR #65 | @aquariouserworkman | +| Gated attention, value residual | PR #940 | Contest community | +| NativeFlowMatcher (this work) | PR #940 experiments | @mcclec07 | + +## File Manifest + +``` +README.md — This file +submission.json — Structured metadata +train_gpt.py — Training/eval script (2,601 lines) +train.log — Training log for seed=42 (SLURM 55342820) +final_model.int6.ptz — Quantized model artifact (15,630,744 bytes) +supplementary/ + eval_nflow7k_legal_ttt.log — Legal TTT eval log, seed=42 (SLURM 55375245) + eval_nflow7k_nottt.log — No-TTT eval log, seed=42 (SLURM 55375246) + eval_e2ettt_flow7k_legal_ttt.log — E2E TTT+Flow eval, INCOMPLETE (SLURM 55375247) + eval_e2ettt_flow7k_legal_ttt_complete.log — E2E TTT+Flow eval, COMPLETE (SLURM 55398555) + slurm_pr940_nflow_7k.sh — Training SLURM script (seed=42) + slurm_eval_nflow7k_legal_ttt.sh — Legal TTT eval SLURM script (seed=42) + slurm_eval_nflow7k_nottt.sh — No-TTT eval SLURM script (seed=42) + seed_runs/ + slurm_nflow_train_s1337.sh — Training script (seed=1337) + slurm_nflow_train_s2025.sh — Training script (seed=2025) + slurm_eval_s1337_legal_ttt.sh — Legal TTT eval (seed=1337) + slurm_eval_s1337_nottt.sh — No-TTT eval (seed=1337) + slurm_eval_s2025_legal_ttt.sh — Legal TTT eval (seed=2025) + slurm_eval_s2025_nottt.sh — No-TTT eval (seed=2025) + train_s1337.log — Training log (SLURM 55398556) + train_s2025.log — Training log (SLURM 55398557) +``` diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/final_model.int6.ptz b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/final_model.int6.ptz new file mode 100644 index 0000000000..402f91ddaa Binary files /dev/null and b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/final_model.int6.ptz differ diff --git 
a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/submission.json b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/submission.json new file mode 100644 index 0000000000..a191830d51 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/submission.json @@ -0,0 +1,92 @@ +{ + "track": "non_record_16mb", + "val_bpb": 1.11990650, + "val_bpb_no_ttt": 1.12311579, + "val_bpb_no_ttt_3seed_mean": 1.12252020, + "val_bpb_no_ttt_3seed_std": 0.00151, + "model_file": "final_model.int6.ptz", + "model_bytes": 15630744, + "code_bytes": 115032, + "total_submission_bytes": 15745776, + "training_tokens_billions": 5.51, + "training_script": "train_gpt.py", + "hardware": "1×A100 PCIe 40GB", + "training_time_hours": 3.86, + "training_steps": 7000, + "quantization": "int6+int5[layers5,6]+zstd-16", + "architecture": "11L-512D-8H-4KV-3xMLP-BigramHash4096-NativeFlowMatcher256", + "num_layers": 11, + "model_dim": 512, + "num_heads": 8, + "num_kv_heads": 4, + "mlp_mult": 3, + "model_params": 27530952, + "nfm_params": 393729, + "seed": 42, + "seeds_attempted": [42, 1337, 2025], + "seeds_completed": [42, 1337, 2025], + "date": "2026-03-31", + "eval_stride": 64, + "legal_ttt": { + "enabled": true, + "optimizer": "sgd", + "lr": 0.002, + "epochs": 10, + "freeze_blocks": 2, + "chunk_tokens": 32768, + "momentum": 0.9, + "grad_clip": 1.0, + "total_chunks": 1893, + "eval_time_seconds": 7190 + }, + "native_flow_matcher": { + "hidden_dim": 256, + "init_scale": 0.01, + "loss_weight": 0.1 + }, + "xsa_last_n": 11, + "bigram_vocab_size": 4096, + "bigram_dim": 128, + "seed_results": { + "seed_42": { + "slurm_training": 55342820, + "slurm_eval_legal_ttt": 55375245, + "slurm_eval_no_ttt": 55375246, + "roundtrip_bpb": 1.14679034, + "sliding_no_ttt_bpb": 1.12311579, + "sliding_legal_ttt_bpb": 1.11990650, + "artifact_bytes": 15745776 + }, + "seed_1337": { + "slurm_training": 55398556, + "slurm_eval_legal_ttt": 55411651, + "slurm_eval_no_ttt": 55411652, + "roundtrip_bpb": 1.14729126, + "sliding_no_ttt_bpb": 1.12366996, + "sliding_legal_ttt_bpb": null, + "artifact_bytes": 15736933 + }, + "seed_2025": { + "slurm_training": 55398557, + "slurm_eval_legal_ttt": 55411653, + "slurm_eval_no_ttt": 55411654, + "roundtrip_bpb": 1.14444585, + "sliding_no_ttt_bpb": 1.12077485, + "sliding_legal_ttt_bpb": null, + "artifact_bytes": 15745950 + } + }, + "slurm_jobs": { + "training": 55342820, + "eval_legal_ttt": 55375245, + "eval_no_ttt": 55375246, + "training_s1337": 55398556, + "training_s2025": 55398557, + "eval_s1337_legal_ttt": 55411651, + "eval_s1337_no_ttt": 55411652, + "eval_s2025_legal_ttt": 55411653, + "eval_s2025_no_ttt": 55411654, + "e2ettt_flow_eval": 55398555 + }, + "note": "Three-seed reproducibility achieved. Training and no-TTT sliding eval complete for all seeds. Legal TTT eval pending for seeds 1337, 2025." 
+} diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_e2ettt_flow7k_legal_ttt.log b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_e2ettt_flow7k_legal_ttt.log new file mode 100644 index 0000000000..1cca4e771c --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_e2ettt_flow7k_legal_ttt.log @@ -0,0 +1,2794 @@ +logs/eval_e2ettt_flow7k_legal_ttt_55375247.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:28319753 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +eval_only: loading /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/e2ettt_flow_55363689/models/final_model_pr940_e2ettt_flow_55363689.pt, skipping training +step:0/0 val_loss:1.9237 val_bpb:1.1393 train_time:63ms step_avg:63.28ms +peak memory allocated: 27459 MiB reserved: 28022 MiB +save_paths: pt=final_model_eval_e2ettt_flow7k_legal_ttt_55375247.pt ptz=final_model_eval_e2ettt_flow7k_legal_ttt_55375247.int6.ptz +Serialized model: 108880403 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16854717 bytes (limit 15884968) +quant_try int6 zstd-1: 16914613 bytes (limit 15884968) +quant_try int6 zstd-17: 16856988 bytes (limit 15884968) +quant_try int6 zstd-2: 16923144 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 16553148 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16617668 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16551154 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16675594 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 16244171 bytes (limit 15884968) +quant_try int5[2L] zstd-1: 16320340 bytes (limit 15884968) +quant_try int5[2L] zstd-17: 16261223 bytes (limit 15884968) +quant_try int5[2L] zstd-2: 16427030 bytes (limit 15884968) +quant_fallback: int5 layers=[4, 5, 6] +quant_try int5[3L] zstd-16: 16007263 bytes (limit 15884968) +quant_try int5[3L] zstd-1: 16024461 bytes (limit 15884968) +quant_try int5[3L] zstd-17: 15970049 bytes (limit 15884968) +quant_try int5[3L] zstd-2: 16181460 bytes (limit 15884968) +quant_fallback: int5 layers=[4, 5, 6, 7] +quant_try int5[4L] zstd-16: 15643338 bytes (limit 15884968) +Serialized model quant+zstd-16: 15643338 bytes +Total submission size: 15758370 bytes +final_int6_roundtrip val_loss:1.9458 val_bpb:1.1524 eval_time:243645ms +final_int6_roundtrip_exact val_loss:1.94578104 val_bpb:1.15240113 +legal_ttt:start 
stride=64 optimizer=sgd lr=0.002 epochs=10 freeze_blocks=2 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=10 ttt_optimizer=sgd freeze_blocks=2 +ttt_sliding:params unfrozen=23588837 frozen=4730916 + ttt_chunk [1/1893] bpb=1.219061 time=7.4s + ttt_chunk [11/1893] bpb=1.120993 time=80.9s + ttt_chunk [21/1893] bpb=1.130134 time=154.5s + ttt_chunk [31/1893] bpb=1.135428 time=228.1s + ttt_chunk [41/1893] bpb=1.131994 time=301.7s + ttt_chunk [51/1893] bpb=1.133329 time=375.3s + ttt_chunk [61/1893] bpb=1.137232 time=448.8s + ttt_chunk [71/1893] bpb=1.135502 time=522.8s + ttt_chunk [81/1893] bpb=1.132059 time=597.6s + ttt_chunk [91/1893] bpb=1.130884 time=672.2s + ttt_chunk [101/1893] bpb=1.131887 time=746.1s + ttt_chunk [111/1893] bpb=1.132078 time=819.8s + ttt_chunk [121/1893] bpb=1.128451 time=893.6s + ttt_chunk [131/1893] bpb=1.127652 time=967.5s + ttt_chunk [141/1893] bpb=1.126534 time=1041.2s + ttt_chunk [151/1893] bpb=1.126669 time=1114.8s + ttt_chunk [161/1893] bpb=1.127536 time=1188.4s + ttt_chunk [171/1893] bpb=1.129615 time=1262.0s + ttt_chunk [181/1893] bpb=1.129613 time=1337.7s + ttt_chunk [191/1893] bpb=1.132008 time=1411.2s + ttt_chunk [201/1893] bpb=1.131532 time=1484.8s + ttt_chunk [211/1893] bpb=1.130573 time=1558.2s + ttt_chunk [221/1893] bpb=1.131473 time=1631.7s + ttt_chunk [231/1893] bpb=1.131180 time=1705.2s + ttt_chunk [241/1893] bpb=1.131385 time=1778.7s + ttt_chunk [251/1893] bpb=1.130903 time=1852.4s + ttt_chunk [261/1893] bpb=1.130156 time=1926.2s + ttt_chunk [271/1893] bpb=1.129245 time=1999.9s + ttt_chunk [281/1893] bpb=1.130732 time=2073.6s + ttt_chunk [291/1893] bpb=1.130322 time=2147.3s + ttt_chunk [301/1893] bpb=1.131145 time=2220.9s + ttt_chunk [311/1893] bpb=1.131184 time=2294.6s + ttt_chunk [321/1893] bpb=1.131877 time=2368.2s + ttt_chunk [331/1893] bpb=1.131323 time=2441.8s + ttt_chunk [341/1893] bpb=1.130879 time=2515.5s + ttt_chunk [351/1893] bpb=1.131565 time=2589.0s + ttt_chunk [361/1893] bpb=1.132309 time=2662.6s + ttt_chunk [371/1893] bpb=1.132205 time=2736.3s + ttt_chunk [381/1893] bpb=1.131957 time=2810.0s + ttt_chunk [391/1893] bpb=1.132640 time=2885.8s + ttt_chunk [401/1893] bpb=1.132159 time=2959.5s + ttt_chunk [411/1893] bpb=1.131170 time=3033.2s + ttt_chunk [421/1893] bpb=1.131239 time=3106.9s + ttt_chunk [431/1893] bpb=1.131658 time=3180.5s + ttt_chunk [441/1893] bpb=1.131042 time=3254.2s + ttt_chunk [451/1893] bpb=1.131166 time=3327.8s + ttt_chunk [461/1893] bpb=1.131031 time=3401.5s + ttt_chunk [471/1893] bpb=1.130577 time=3475.1s + ttt_chunk [481/1893] bpb=1.130385 time=3548.8s + ttt_chunk [491/1893] bpb=1.130544 time=3622.4s + ttt_chunk [501/1893] bpb=1.130284 time=3696.0s + ttt_chunk [511/1893] bpb=1.129779 time=3769.6s + ttt_chunk [521/1893] bpb=1.129314 time=3843.1s + ttt_chunk [531/1893] bpb=1.130018 time=3916.7s + ttt_chunk [541/1893] bpb=1.130099 time=3990.2s + ttt_chunk [551/1893] bpb=1.129554 time=4063.6s + ttt_chunk [561/1893] bpb=1.129400 time=4136.9s + ttt_chunk [571/1893] bpb=1.129109 time=4210.3s + ttt_chunk [581/1893] bpb=1.128706 time=4283.6s + ttt_chunk [591/1893] bpb=1.128130 time=4357.0s + ttt_chunk [601/1893] bpb=1.128118 time=4430.3s + ttt_chunk [611/1893] bpb=1.127777 time=4503.6s + ttt_chunk [621/1893] bpb=1.127617 time=4577.0s + ttt_chunk [631/1893] bpb=1.127359 time=4650.2s + ttt_chunk [641/1893] bpb=1.126912 time=4723.6s + ttt_chunk [651/1893] bpb=1.126447 time=4796.9s + ttt_chunk [661/1893] bpb=1.126330 time=4870.3s + ttt_chunk [671/1893] bpb=1.125836 
time=4943.6s + ttt_chunk [681/1893] bpb=1.125266 time=5016.9s + ttt_chunk [691/1893] bpb=1.125340 time=5090.3s + ttt_chunk [701/1893] bpb=1.124500 time=5163.6s + ttt_chunk [711/1893] bpb=1.124502 time=5236.9s + ttt_chunk [721/1893] bpb=1.124411 time=5310.2s + ttt_chunk [731/1893] bpb=1.124652 time=5383.5s + ttt_chunk [741/1893] bpb=1.124528 time=5456.8s + ttt_chunk [751/1893] bpb=1.124227 time=5530.2s + ttt_chunk [761/1893] bpb=1.124358 time=5603.5s + ttt_chunk [771/1893] bpb=1.124185 time=5676.8s + ttt_chunk [781/1893] bpb=1.124354 time=5750.1s + ttt_chunk [791/1893] bpb=1.124211 time=5823.5s + ttt_chunk [801/1893] bpb=1.124147 time=5896.8s + ttt_chunk [811/1893] bpb=1.124151 time=5970.1s + ttt_chunk [821/1893] bpb=1.124052 time=6043.4s + ttt_chunk [831/1893] bpb=1.123777 time=6116.7s + ttt_chunk [841/1893] bpb=1.123545 time=6190.1s + ttt_chunk [851/1893] bpb=1.123609 time=6263.4s + ttt_chunk [861/1893] bpb=1.123678 time=6336.7s + ttt_chunk [871/1893] bpb=1.123882 time=6410.0s + ttt_chunk [881/1893] bpb=1.123884 time=6483.0s + ttt_chunk [891/1893] bpb=1.123365 time=6555.8s + ttt_chunk [901/1893] bpb=1.123387 time=6628.5s + ttt_chunk [911/1893] bpb=1.123230 time=6701.3s + ttt_chunk [921/1893] bpb=1.123363 time=6774.0s + ttt_chunk [931/1893] bpb=1.123317 time=6846.8s + ttt_chunk [941/1893] bpb=1.123522 time=6919.5s + ttt_chunk [951/1893] bpb=1.123821 time=6992.3s + ttt_chunk [961/1893] bpb=1.124109 time=7065.0s + ttt_chunk [971/1893] bpb=1.124460 time=7137.7s + ttt_chunk [981/1893] bpb=1.124663 time=7210.3s + ttt_chunk [991/1893] bpb=1.124562 time=7283.1s + ttt_chunk [1001/1893] bpb=1.124875 time=7355.8s + ttt_chunk [1011/1893] bpb=1.125008 time=7430.4s + ttt_chunk [1021/1893] bpb=1.125291 time=7503.1s + ttt_chunk [1031/1893] bpb=1.125664 time=7575.8s + ttt_chunk [1041/1893] bpb=1.126175 time=7648.5s + ttt_chunk [1051/1893] bpb=1.126029 time=7721.1s + ttt_chunk [1061/1893] bpb=1.126122 time=7793.8s + ttt_chunk [1071/1893] bpb=1.126273 time=7866.4s + ttt_chunk [1081/1893] bpb=1.126313 time=7939.1s + ttt_chunk [1091/1893] bpb=1.126566 time=8011.8s + ttt_chunk [1101/1893] bpb=1.126699 time=8084.5s + ttt_chunk [1111/1893] bpb=1.126441 time=8157.1s + ttt_chunk [1121/1893] bpb=1.126208 time=8229.8s + ttt_chunk [1131/1893] bpb=1.126093 time=8302.4s + ttt_chunk [1141/1893] bpb=1.125849 time=8375.0s + ttt_chunk [1151/1893] bpb=1.125863 time=8447.6s + ttt_chunk [1161/1893] bpb=1.125649 time=8520.3s + ttt_chunk [1171/1893] bpb=1.125474 time=8592.9s + ttt_chunk [1181/1893] bpb=1.125243 time=8665.6s + ttt_chunk [1191/1893] bpb=1.125389 time=8738.2s + ttt_chunk [1201/1893] bpb=1.125588 time=8810.9s + ttt_chunk [1211/1893] bpb=1.125185 time=8883.6s + ttt_chunk [1221/1893] bpb=1.125510 time=8956.3s + ttt_chunk [1231/1893] bpb=1.125434 time=9029.0s + ttt_chunk [1241/1893] bpb=1.125130 time=9101.6s + ttt_chunk [1251/1893] bpb=1.124594 time=9174.3s + ttt_chunk [1261/1893] bpb=1.124331 time=9247.0s + ttt_chunk [1271/1893] bpb=1.124082 time=9319.6s + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP 
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
+
+def build_sentencepiece_luts(
+    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
+) -> tuple[Tensor, Tensor, Tensor]:
+    """Per-token LUTs for BPB accounting: base UTF-8 bytes, leading-space flag, and boundary (control/unknown/unused) flag."""
+    sp_vocab_size = int(sp.vocab_size())
+    table_size = max(sp_vocab_size, vocab_size)
+    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
+    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
+    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
+    for token_id in range(sp_vocab_size):
+        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
+            continue
+        is_boundary_token_np[token_id] = False
+        if sp.is_byte(token_id):
+            base_bytes_np[token_id] = 1
+            continue
+        piece = sp.id_to_piece(token_id)
+        if piece.startswith("▁"):
+            has_leading_space_np[token_id] = True
+            piece = piece[1:]
+        base_bytes_np[token_id] = len(piece.encode("utf-8"))
+    return (
+        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
+        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
+        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
+    )
+
+
+def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
+    files = [Path(p) for p in sorted(glob.glob(pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,vr_lambda,attn_gate,canon_a,canon_c,delta_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, qmax: int = 127) -> tuple[Tensor, Tensor]: + """Quantize to [-qmax, qmax] range. Default int8 (qmax=127), int6 (qmax=31), int5 (qmax=15).""" + t32 = t.float() + qmin = -qmax + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 for MLP layers 3-7 to save artifact space + int6_mlp_layers = os.environ.get("INT6_MLP_LAYERS", "") + qmax = 127 # default int8 + if int6_mlp_layers: + for li in int6_mlp_layers.split(","): + if li.strip() and f"blocks.{li.strip()}.mlp" in name and t.ndim == 2: + qmax = 31 # int6 + break + q, s = quantize_float_tensor(t, qmax=qmax) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round: bool = False + _soft_round_alpha: float = 1.0 + _quant_percentile: float = float(os.environ.get("QUANT_PERCENTILE", "1.0")) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + w32 = self.weight.float() + pct = CastedLinear._quant_percentile + row_max = (torch.quantile(w32.abs(), pct, dim=1) if pct < 1.0 + else w32.abs().amax(dim=1)).detach() + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + r = w32 / scale[:, None] + if CastedLinear._soft_round: + alpha = CastedLinear._soft_round_alpha + r_frac = r - r.detach().floor() - 0.5 + norm = torch.tanh(torch.tensor(alpha * 0.5, device=r.device, dtype=r.dtype)) + r_soft = r.detach().floor() + 0.5 + torch.tanh(alpha * r_frac) / (2.0 * norm) + w_q = (torch.clamp(r_soft, -32, 31) * scale[:, None]).to(x.dtype) + w = w_q # soft-round is differentiable, no STE needed + else: + with torch.no_grad(): + w_q = (torch.clamp(torch.round(r), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. 
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + value_residual: bool = False, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = 
v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + fa_dtype = torch.bfloat16 + if _USE_FA3: + y = flash_attn_3_func(q.to(fa_dtype), k.to(fa_dtype), v.to(fa_dtype), causal=True) + else: + # SDPA fallback: (B, T, H, D) -> (B, H, T, D), expand KV for GQA + q_t = q.to(fa_dtype).transpose(1, 2) + k_t = k.to(fa_dtype).transpose(1, 2) + v_t = v.to(fa_dtype).transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(rep, dim=1) + v_t = v_t.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True) + y = y.transpose(1, 2) # (B, H, T, D) -> (B, T, H, D) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)) # (B, T, num_heads) + y = y * gate.unsqueeze(-1) # (B, T, H, 1) broadcast to (B, T, H, D) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.use_leaky = bool(int(os.environ.get("LEAKY_RELU", "1"))) + self.leaky_slope = 
float(os.environ.get("LEAKY_SLOPE", "0.5")) + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), self.leaky_slope) if self.use_leaky else torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class CanonAC(nn.Module): + """Canon Autoregressive Convolution with DeltaGate. Manual shift+mul (no Conv1d).""" + def __init__(self, dim: int, kernel: int = 4, delta_gate_init: float = -4.0): + super().__init__() + self.kernel = kernel + self.weight = nn.Parameter(torch.zeros(kernel, dim)) + self.delta_gate_logit = nn.Parameter(torch.tensor(delta_gate_init)) + + def forward(self, x: Tensor) -> Tensor: + B, T, D = x.shape + K = self.kernel + w = self.weight.to(x.dtype) + x_pad = F.pad(x, (0, 0, K - 1, 0)) + y = w[0] * x_pad[:, K - 1:, :] + for k in range(1, K): + y = y + w[k] * x_pad[:, K - 1 - k : T + K - 1 - k, :] + gate = torch.sigmoid(self.delta_gate_logit.to(x.dtype)) + return x + gate * y + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_kernel: int = 0, + canon_delta_gate_init: float = -4.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, value_residual=value_residual, + gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.canon_a = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + self.canon_c = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_in = self.attn_norm(x) * s + if self.canon_a is not None: + attn_in = self.canon_a(attn_in) + attn_out, raw_v = self.attn(attn_in, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s + if self.canon_c is not None: + mlp_in = self.canon_c(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + return x, raw_v + + +class FlowRefiner(nn.Module): + """1-step flow matching refiner in low-dim latent space. + + Projects hidden states to a low-dim latent, applies a learned velocity + field (1-step Euler), projects back. Output is additive adjustment to + hidden states. Initialized near-zero so the refiner starts as identity. 
+ """ + def __init__(self, model_dim: int, latent_dim: int = 64, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.down_proj = nn.Linear(model_dim, latent_dim, bias=False) + self.velocity_net = nn.Sequential( + nn.Linear(latent_dim, hidden_dim, bias=True), + nn.GELU(), + nn.Linear(hidden_dim, latent_dim, bias=True), + ) + self.up_proj = nn.Linear(latent_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.velocity_net[2]._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.down_proj.weight, std=self._init_scale) + nn.init.normal_(self.velocity_net[0].weight, std=self._init_scale) + nn.init.zeros_(self.velocity_net[0].bias) + nn.init.zeros_(self.velocity_net[2].weight) + nn.init.zeros_(self.velocity_net[2].bias) + nn.init.normal_(self.up_proj.weight, std=self._init_scale) + + def forward(self, x: Tensor) -> Tensor: + z = self.down_proj(x) + v = self.velocity_net(z) + z_refined = z + v + delta = self.up_proj(z_refined) + return torch.sigmoid(self.gate) * delta + + +class NativeFlowMatcher(nn.Module): + """Conditional Flow Matching refiner for hidden states. + + During training: computes auxiliary CFM loss on interpolated states. + During inference: applies a single Euler step correction at t=1. + """ + def __init__(self, model_dim: int, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.model_dim = model_dim + # Sinusoidal time embedding → projected to hidden_dim + self.time_proj = nn.Sequential( + nn.Linear(model_dim, hidden_dim, bias=True), + nn.GELU(), + ) + # Velocity network: x → hidden → x (with time conditioning via addition) + self.v_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.v_act = nn.GELU() + self.v_out = nn.Linear(hidden_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.v_out._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.v_in.weight, std=self._init_scale) + nn.init.zeros_(self.v_in.bias) + nn.init.normal_(self.time_proj[0].weight, std=self._init_scale) + nn.init.zeros_(self.time_proj[0].bias) + nn.init.zeros_(self.v_out.weight) + + @staticmethod + def _sinusoidal_time_emb(t: Tensor, dim: int) -> Tensor: + """Sinusoidal positional embedding for scalar time t. t: (...,) -> (..., dim).""" + half_dim = dim // 2 + emb = math.log(10000.0) / max(half_dim - 1, 1) + emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb) + emb = t.unsqueeze(-1) * emb # (..., half_dim) + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) # (..., dim) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + def _velocity(self, x: Tensor, t_emb: Tensor) -> Tensor: + """Compute velocity v(x, t). x: (..., model_dim), t_emb: (..., hidden_dim).""" + h = self.v_in(x) # (..., hidden_dim) + h = h + t_emb # time conditioning via addition + h = self.v_act(h) # GELU + return self.v_out(h) # (..., model_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (correction, cfm_loss). + + During training: cfm_loss is the flow matching objective (> 0). + During eval: cfm_loss is zero. + correction: gated velocity at t=1, to be added to x. 
+ """ + B_seq = x.shape[:-1] # works for any leading dimensions + D = self.model_dim + + # Inference path: velocity at t=1 (clean input, no noise) + t_ones = torch.ones(*B_seq, device=x.device, dtype=x.dtype) + t_emb_ones = self.time_proj(self._sinusoidal_time_emb(t_ones, D)) + v_at_1 = self._velocity(x, t_emb_ones) + correction = torch.sigmoid(self.gate) * v_at_1 + + # Training path: auxiliary CFM loss + if self.training: + # Sample random t ~ U(0,1) per position + t = torch.rand(*B_seq, device=x.device, dtype=x.dtype) + z = torch.randn_like(x) # noise + # OT interpolant: x_t = (1-t)*z + t*x + t_expanded = t.unsqueeze(-1) # (..., 1) + x_t = (1.0 - t_expanded) * z + t_expanded * x.detach() + # Target velocity = x - z (OT conditional velocity field) + target_v = x.detach() - z + # Predict velocity at interpolated point + t_emb = self.time_proj(self._sinusoidal_time_emb(t, D)) + pred_v = self._velocity(x_t, t_emb) + cfm_loss = F.mse_loss(pred_v, target_v) + else: + cfm_loss = x.new_zeros(()) + + return correction, cfm_loss + + +class TTTLinearRefiner(nn.Module): + """E2E TTT-Linear refiner (Sun et al., 2024, arXiv:2407.04620). + + Applied to final hidden states [B, L, D] before lm_head. The hidden + state is a per-head linear model W updated via mini-batch gradient + descent on a learned self-supervised reconstruction task. Uses the + dual form for GPU efficiency. Trained end-to-end: the outer loop + optimizes theta_K, theta_V, theta_Q projections and the inner-loop + initialization W_0, making the model learn WHAT to compress. + """ + + def __init__(self, model_dim: int, num_heads: int = 8, + mini_batch_size: int = 16, ttt_base_lr: float = 1.0): + super().__init__() + self.model_dim = model_dim + self.num_heads = num_heads + self.head_dim = model_dim // num_heads + self.mini_batch_size = mini_batch_size + self.ttt_base_lr = ttt_base_lr + self.eta = ttt_base_lr / self.head_dim + + # Learnable projections (outer-loop) + self.theta_K = nn.Linear(model_dim, model_dim, bias=False) + self.theta_V = nn.Linear(model_dim, model_dim, bias=False) + self.theta_Q = nn.Linear(model_dim, model_dim, bias=False) + + # Inner-loop model: W_0, b_0 + self.W1 = nn.Parameter(torch.normal(0, 0.02, + size=(num_heads, self.head_dim, self.head_dim))) + self.b1 = nn.Parameter(torch.zeros(num_heads, 1, self.head_dim)) + + # Inner-loop LayerNorm (outer-loop trainable) + self.ttt_ln_w = nn.Parameter(torch.ones(num_heads, self.head_dim)) + self.ttt_ln_b = nn.Parameter(torch.zeros(num_heads, self.head_dim)) + + # Output projection + gate (starts near zero for residual safety) + self.o_proj = nn.Linear(model_dim, model_dim, bias=False) + self.post_norm = nn.LayerNorm(model_dim, eps=1e-6) + self.gate = nn.Parameter(torch.tensor(-5.0)) + nn.init.zeros_(self.o_proj.weight) + self.o_proj._zero_init = True + + @staticmethod + def _ln_fwd(x: Tensor, gamma: Tensor, beta: Tensor, eps: float = 1e-6) -> Tensor: + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x_hat = (x - mu) / torch.sqrt(var + eps) + return gamma * x_hat + beta + + @staticmethod + def _ln_fused_l2_bwd(x: Tensor, target: Tensor, gamma: Tensor, beta: Tensor, + eps: float = 1e-6) -> Tensor: + D = x.shape[-1] + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + std = torch.sqrt(var + eps) + x_hat = (x - mu) / std + y = gamma * x_hat + beta + grad_output = y - target + grad_x_hat = grad_output * gamma + z = (1.0 / D) * ( + D * grad_x_hat + - grad_x_hat.sum(dim=-1, keepdim=True) + - x_hat * 
(grad_x_hat * x_hat).sum(dim=-1, keepdim=True) + ) / std + return z + + def forward(self, x: Tensor) -> Tensor: + """x: [B, L, D]. Returns [B, L, D] additive refinement (gated).""" + B, L, D = x.shape + NH, HD, b = self.num_heads, self.head_dim, self.mini_batch_size + eta = self.eta + + num_mb = L // b + usable_L = num_mb * b + x_proc = x[:, :usable_L] if usable_L < L else x + + # Project: [B, L, D] -> [B, NH, L, HD] + XK = self.theta_K(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XV = self.theta_V(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XQ = self.theta_Q(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + + # Reshape into mini-batches: [B, NH, num_mb, b, HD] + XK = XK.reshape(B, NH, num_mb, b, HD) + XV = XV.reshape(B, NH, num_mb, b, HD) + XQ = XQ.reshape(B, NH, num_mb, b, HD) + + ln_w = self.ttt_ln_w.reshape(1, NH, 1, HD) + ln_b = self.ttt_ln_b.reshape(1, NH, 1, HD) + + # Causal mask for dual form + causal = torch.tril(torch.ones(b, b, device=x.device, dtype=x.dtype)) + + # Initialize inner-loop model + W = self.W1.unsqueeze(0).expand(B, -1, -1, -1).clone() + bias = self.b1.unsqueeze(0).expand(B, -1, -1, -1).clone() + + all_out = [] + for m in range(num_mb): + xk = XK[:, :, m] + xv = XV[:, :, m] + xq = XQ[:, :, m] + + Z1 = xk @ W + bias + target = xv - xk + + grad = self._ln_fused_l2_bwd(Z1, target, ln_w, ln_b) + + Attn = causal.unsqueeze(0).unsqueeze(0) * (xq @ xk.transpose(-2, -1)) + cumgrad = (causal.unsqueeze(0).unsqueeze(0) @ grad) + b_bar = bias - eta * cumgrad + Z_bar = xq @ W - eta * Attn @ grad + b_bar + + W = W - eta * xk.transpose(-2, -1) @ grad + bias = bias - eta * grad.sum(dim=-2, keepdim=True) + + Z_bar = self._ln_fwd(Z_bar, ln_w, ln_b) + out_mb = xq + Z_bar + all_out.append(out_mb) + + z = torch.cat(all_out, dim=2).permute(0, 2, 1, 3).reshape(B, usable_L, D) + z = self.post_norm(z) + z = self.o_proj(z) + result = torch.sigmoid(self.gate) * z + + if usable_L < L: + pad = torch.zeros(B, L - usable_L, D, device=x.device, dtype=x.dtype) + result = torch.cat([result, pad], dim=1) + + return result + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_last_n: int = 0, + canon_kernel: int = 4, + canon_delta_gate_init: float = -4.0, + flow_enabled: bool = False, + flow_latent_dim: int = 64, + flow_hidden_dim: int = 256, + flow_init_scale: float = 0.01, + native_flow_enabled: bool = False, + native_flow_hidden_dim: int = 256, + native_flow_init_scale: float = 0.01, + native_flow_loss_weight: float = 0.1, + e2e_ttt_enabled: bool = False, + e2e_ttt_num_heads: int = 8, + e2e_ttt_mini_batch: int = 16, + e2e_ttt_base_lr: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, 
bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + canon_start = num_layers - canon_last_n if canon_last_n > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + value_residual=value_residual, + gated_attention=gated_attention, + canon_kernel=canon_kernel if i >= canon_start else 0, + canon_delta_gate_init=canon_delta_gate_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.flow_refiner = FlowRefiner(model_dim, flow_latent_dim, flow_hidden_dim, flow_init_scale) if flow_enabled else None + self.native_flow = NativeFlowMatcher(model_dim, native_flow_hidden_dim, native_flow_init_scale) if native_flow_enabled else None + self.native_flow_loss_weight = native_flow_loss_weight + self.ttt_refiner = TTTLinearRefiner(model_dim, e2e_ttt_num_heads, e2e_ttt_mini_batch, e2e_ttt_base_lr) if e2e_ttt_enabled else None + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + nflow_loss = x.new_zeros(()) + if self.native_flow is not None: + correction, nflow_loss = self.native_flow(x) + x = x + correction + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + if self.native_flow is not None and self.training: + main_loss = main_loss + self.native_flow_loss_weight * nflow_loss + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + if self.native_flow is not None: + correction, _ = self.native_flow(x) + x = x + correction + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + 
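+# Illustrative sketch (not called anywhere; the helper name is ours, not part of the
+# original script): which target positions a single sliding window contributes to the
+# score. The first window (ws == 0) scores its full length; every later window scores
+# only its last `stride` targets, so each interior token is scored once with at least
+# seq_len - stride tokens of left context.
+def _scored_targets_sketch(ws: int, total_tokens: int, seq_len: int, stride: int) -> range:
+    end = min(ws + seq_len, total_tokens)      # window covers targets [ws, end)
+    wlen = end - ws
+    s = 0 if ws == 0 else max(wlen - stride, 0)
+    return range(ws + s, ws + wlen)            # target indices this window contributes
+
+# Example with seq_len=1024, stride=64 (illustrative numbers):
+#   _scored_targets_sketch(0,   10_000_000, 1024, 64) -> range(0, 1024)
+#   _scored_targets_sketch(64,  10_000_000, 1024, 64) -> range(1024, 1088)
+#   _scored_targets_sketch(128, 10_000_000, 1024, 64) -> range(1088, 1152)
+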
+def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + Optionally uses entropy-gated 5-gram cache (NGRAM_CACHE=1).""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # N-gram eval cache with multi-order backoff + entropy-adaptive alpha (PR #702 inspired) + _ngram_default = "1" if world_size > 1 else "0" + use_ngram = bool(int(os.environ.get("NGRAM_CACHE", _ngram_default))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.40")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_order = int(os.environ.get("NGRAM_ORDER", "7")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = ngram_order - ngram_min_order + 1 + ctx_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + print(f"ngram_cache:enabled orders={ngram_min_order}-{ngram_order} backoff " + f"entropy={ngram_entropy} alpha={ngram_alpha} " + f"ent_base={ngram_ent_base} ent_range={ngram_ent_range} " + f"min_count={ngram_min_count} buckets={ngram_buckets}", flush=True) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + scored_nll = nll[i, s:wlen].to(torch.float64) + + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha: compute from model logits (GPU) + if ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ngram_ent_base + ngram_ent_range / ( + 1.0 + np.exp(-ngram_ent_scale * (seg_ent - ngram_ent_thresh))) + + # Precompute hashes for all orders + order_data = [] # (v_idx, ctx_key, full_key) per order + for oi in range(_n_orders): + ctx_w = ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, fill unmatched with lower orders + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if ngram_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll_np = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + scored_nll = torch.from_numpy(seg_nll_np).to(dtype=torch.float64, device=device) + + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + 
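+# Minimal numeric sketch of the nats -> bits-per-byte conversion returned above
+# (helper name and example numbers are illustrative only, not part of the original
+# script; relies on the module-level `math` import):
+def _bpb_sketch(loss_sum: float, token_count: float, byte_count: float) -> float:
+    mean_nll = loss_sum / token_count                    # nats per token
+    bits_per_token = mean_nll / math.log(2.0)            # bits per token
+    return bits_per_token * (token_count / byte_count)   # bits per byte
+
+# e.g. with a mean NLL of 1.89 nats/token and ~2.44 bytes/token:
+#   _bpb_sketch(1.89e6, 1e6, 2.44e6) ≈ 1.89 / ln(2) / 2.44 ≈ 1.12
+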
+ +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _clear_rotary_caches(model: nn.Module) -> None: + """Clear cached RoPE tensors to avoid 'Inference tensors cannot be saved for backward'.""" + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + + +def ttt_adapt(args: Hyperparameters, base_model: nn.Module, device: torch.device, + val_tokens: Tensor, rank: int = 0, world_size: int = 1, + log_fn=None) -> None: + """Score-first TTT: process val data in chunks, score each chunk first + (inference_mode), then train on scored tokens. Compliant with Issue #677.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + batch_seqs = args.ttt_batch_seqs + + # Freeze early blocks + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + chunk_idx = 0 + + for chunk_start in range(0, total_tokens - seq_len, chunk_tokens): + chunk_end = min(chunk_start + chunk_tokens, total_tokens) + chunk_len = chunk_end - chunk_start + n_seqs = chunk_len // seq_len + if n_seqs == 0: + break + + my_start = (n_seqs * rank) // world_size + my_end = (n_seqs * (rank + 1)) // world_size + if my_end <= my_start: + continue + + # Phase 1: Score chunk under inference_mode (forward only) + base_model.eval() + with torch.inference_mode(): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + base_model.forward_logits(x) + + # Phase 2: Train on scored tokens (K epochs) + base_model.train() + for epoch in range(args.ttt_epochs): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + chunk_idx += 1 + if log_fn and chunk_idx % 20 == 0: + log_fn(f"ttt:chunk={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + # Restore all params + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done chunks={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, 
+ eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + """Legal single-pass TTT: score each chunk with sliding windows, then train on it. + Tokens are always scored BEFORE any training on their chunk, so the evaluation + is never contaminated by future information.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Build window starts (same logic as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Map each window to the chunk that contains its first scored token + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + if log_fn: + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + if log_fn: + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + _clear_rotary_caches(base_model) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR schedule with 5% warmup + warmup_chunks = max(num_chunks // 20, 1) + if ci < warmup_chunks: + lr_scale = (ci + 1) / warmup_chunks + else: + progress = (ci - warmup_chunks) / max(num_chunks - 1 - warmup_chunks, 1) + lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress)) + cos_lr = args.ttt_lr * lr_scale + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if log_fn and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all params and return to eval mode + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + if log_fn: + log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, qmax: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + qmin = -qmax - 1 + pct = CastedLinear._quant_percentile + if t32.ndim == 2: + row_max = (torch.quantile(t32.abs(), pct, dim=1) if pct < 1.0 + else t32.abs().amax(dim=1)) + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)).to(torch.float16) + clipped = t32.clamp(-row_max[:, None], row_max[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), qmin, qmax).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / float(qmax) if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), qmin, qmax).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_layers: set[int] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + if int5_layers is None: + int5_layers = set() + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Determine layer index for int5 fallback + layer_idx = -1 + if name.startswith("blocks."): + try: + layer_idx = int(name.split(".")[1]) + except (IndexError, ValueError): + pass + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + qmax = 15 if layer_idx in int5_layers else 31 + q, s = 
quantize_int6_per_row(t, qmax=qmax) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5" if qmax == 15 else "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + if _USE_FA3: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.flow_refiner is not None: + scalar_params.extend(list(base_model.flow_refiner.parameters())) + if base_model.native_flow is not None: + scalar_params.extend(list(base_model.native_flow.parameters())) + if base_model.ttt_refiner is not None: + scalar_params.extend(list(base_model.ttt_refiner.parameters())) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:fa3={_USE_FA3} cudnn=False flash=True mem_efficient={not _USE_FA3} math={not _USE_FA3}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 
1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + eval_only_path = os.environ.get("EVAL_ONLY", "") + if eval_only_path: + log0(f"eval_only: loading {eval_only_path}, skipping training") + base_model.load_state_dict(torch.load(eval_only_path, map_location=device, weights_only=False), strict=False) + ema_state = None # prevent random EMA from overwriting loaded weights + swa_state = None + swa_count = 0 + args.iterations = 0 # skip training, go straight to eval + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms 
step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if ( + args.checkpoint_every > 0 + and step > 0 + and step % args.checkpoint_every == 0 + and not last_step + and master_process + ): + ckpt_sd = {k: v for k, v in base_model.state_dict().items() if "mtp_heads" not in k} + ckpt_path = f"checkpoint_step{step}_{args.run_id}.pt" + torch.save(ckpt_sd, ckpt_path) + log0(f"checkpoint_saved: {ckpt_path} ({os.path.getsize(ckpt_path)} bytes)") + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._soft_round = args.soft_round_qat + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round:{args.soft_round_qat}") + if CastedLinear._qat_enabled and CastedLinear._soft_round: + qat_progress = max(0.0, 1.0 - (scale / qat_threshold)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress # 1→16 + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: + # Tight SWA: collect from EMA state if available, else from raw model + src = ema_state if ema_state is not None else {name: t.detach().float() for name, t in base_model.state_dict().items()} + if swa_state is None: + swa_state = {name: t.clone() for name, t in src.items()} + swa_count = 1 + log0(f"swa:start step:{step} tight={ema_state is not None}") + else: + for name in swa_state: + swa_state[name].add_(src[name]) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the 
wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying Tight SWA averaged {swa_count} EMA checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + if ema_state is not None: + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + # Use RUN_ID for unique filenames when multiple jobs share CWD + _run_id = os.environ.get("RUN_ID", "") + _save_prefix = f"final_model_{_run_id}" if _run_id else "final_model" + _pt_path = f"{_save_prefix}.pt" + _ptz_path = f"{_save_prefix}.int6.ptz" + log0(f"save_paths: pt={_pt_path} ptz={_ptz_path}") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, _pt_path) + model_bytes = os.path.getsize(_pt_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + code_bytes = len(code.encode("utf-8")) + artifact_limit = 16_000_000 - code_bytes + + # --- Auto-downgrade quantization: try int6 first, fall back to int5 middle layers --- + num_layers_total = max( + (int(k.split(".")[1]) for k in sd_cpu if k.startswith("blocks.")), + default=0, + ) + 1 + _zstd_levels = [int(os.environ.get("ZSTD_LEVEL", "16")), 1, 17, 2] + # Phase 1: pure int6 with multiple zstd levels + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = None + chosen_level = _zstd_levels[0] + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int6 zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + # Phase 2: progressive int5 fallback — one layer at a time from middle outward + if quant_blob is None: + mid = num_layers_total // 2 + # Expand outward from center: L5, L4, L6, L3, L7, L2, L8, ... 
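+        # With 11 layers (mid = 5) the loop below visits 5, 6, 4, 7, 3, 8, 2, 9, 1, 10, 0:
+        # at each offset the layer above the midpoint is tried before the one below, which
+        # matches the "int5 layers=[5]" then "[5, 6]" fallback order in the log further down.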
+ candidates = [] + for offset in range(num_layers_total): + for sign in [0, 1]: + layer = mid + offset if sign == 0 else mid - offset + if 0 <= layer < num_layers_total and layer not in candidates: + candidates.append(layer) + int5_layers: set[int] = set() + for layer in candidates: + int5_layers.add(layer) + if master_process: + log0(f"quant_fallback: int5 layers={sorted(int5_layers)}") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, int5_layers=int5_layers) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int5[{len(int5_layers)}L] zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + if quant_blob is not None: + break + if quant_blob is None: + quant_blob = blob # Use last attempt even if over limit + if master_process: + log0(f"WARNING: artifact still over limit after all fallbacks") + if master_process: + with open(_ptz_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + log0(f"Serialized model quant+{_COMPRESSOR}-{chosen_level}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open(_ptz_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + legal_ttt = bool(int(os.environ.get("LEGAL_TTT", "0"))) + if args.ttt_enabled and not legal_ttt: + # --- 
Invalid two-pass TTT (adapt then eval separately) --- + if distributed: + dist.barrier() + for block in eval_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + log0(f"ttt:start score-first optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"chunk_tokens={args.ttt_chunk_tokens}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, + rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + if legal_ttt and args.ttt_enabled: + # Legal single-pass TTT: score → train interleaved per chunk + log0(f"legal_ttt:start stride={args.eval_stride} " + f"optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + sw_val_loss, sw_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + log_fn=log0, + ) + else: + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + + +==================================================================================================== +Running Python 3.12.7 | packaged by 
conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] +Running PyTorch 2.10.0+cu128 +Sun Mar 29 20:03:52 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.172.08 Driver Version: 570.172.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA A100-PCIE-40GB On | 00000000:17:00.0 Off | 0 | +| N/A 35C P0 49W / 250W | 423MiB / 40960MiB | 3% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 401065 C ...ameter_golf/.venv/bin/python3 414MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:28319753 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +eval_only: loading /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/e2ettt_flow_55363689/models/final_model_pr940_e2ettt_flow_55363689.pt, skipping training +step:0/0 val_loss:1.9237 val_bpb:1.1393 train_time:63ms step_avg:63.28ms +peak memory allocated: 27459 MiB reserved: 28022 MiB +save_paths: pt=final_model_eval_e2ettt_flow7k_legal_ttt_55375247.pt ptz=final_model_eval_e2ettt_flow7k_legal_ttt_55375247.int6.ptz +Serialized model: 108880403 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16854717 bytes (limit 15884968) +quant_try int6 zstd-1: 16914613 bytes (limit 15884968) +quant_try int6 zstd-17: 16856988 bytes (limit 15884968) +quant_try int6 zstd-2: 16923144 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 16553148 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16617668 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16551154 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16675594 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] 
+quant_try int5[2L] zstd-16: 16244171 bytes (limit 15884968) +quant_try int5[2L] zstd-1: 16320340 bytes (limit 15884968) +quant_try int5[2L] zstd-17: 16261223 bytes (limit 15884968) +quant_try int5[2L] zstd-2: 16427030 bytes (limit 15884968) +quant_fallback: int5 layers=[4, 5, 6] +quant_try int5[3L] zstd-16: 16007263 bytes (limit 15884968) +quant_try int5[3L] zstd-1: 16024461 bytes (limit 15884968) +quant_try int5[3L] zstd-17: 15970049 bytes (limit 15884968) +quant_try int5[3L] zstd-2: 16181460 bytes (limit 15884968) +quant_fallback: int5 layers=[4, 5, 6, 7] +quant_try int5[4L] zstd-16: 15643338 bytes (limit 15884968) +Serialized model quant+zstd-16: 15643338 bytes +Total submission size: 15758370 bytes +final_int6_roundtrip val_loss:1.9458 val_bpb:1.1524 eval_time:243645ms +final_int6_roundtrip_exact val_loss:1.94578104 val_bpb:1.15240113 +legal_ttt:start stride=64 optimizer=sgd lr=0.002 epochs=10 freeze_blocks=2 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=10 ttt_optimizer=sgd freeze_blocks=2 +ttt_sliding:params unfrozen=23588837 frozen=4730916 + ttt_chunk [1/1893] bpb=1.219061 time=7.4s + ttt_chunk [11/1893] bpb=1.120993 time=80.9s + ttt_chunk [21/1893] bpb=1.130134 time=154.5s + ttt_chunk [31/1893] bpb=1.135428 time=228.1s + ttt_chunk [41/1893] bpb=1.131994 time=301.7s + ttt_chunk [51/1893] bpb=1.133329 time=375.3s + ttt_chunk [61/1893] bpb=1.137232 time=448.8s + ttt_chunk [71/1893] bpb=1.135502 time=522.8s + ttt_chunk [81/1893] bpb=1.132059 time=597.6s + ttt_chunk [91/1893] bpb=1.130884 time=672.2s + ttt_chunk [101/1893] bpb=1.131887 time=746.1s + ttt_chunk [111/1893] bpb=1.132078 time=819.8s + ttt_chunk [121/1893] bpb=1.128451 time=893.6s + ttt_chunk [131/1893] bpb=1.127652 time=967.5s + ttt_chunk [141/1893] bpb=1.126534 time=1041.2s + ttt_chunk [151/1893] bpb=1.126669 time=1114.8s + ttt_chunk [161/1893] bpb=1.127536 time=1188.4s + ttt_chunk [171/1893] bpb=1.129615 time=1262.0s + ttt_chunk [181/1893] bpb=1.129613 time=1337.7s + ttt_chunk [191/1893] bpb=1.132008 time=1411.2s + ttt_chunk [201/1893] bpb=1.131532 time=1484.8s + ttt_chunk [211/1893] bpb=1.130573 time=1558.2s + ttt_chunk [221/1893] bpb=1.131473 time=1631.7s + ttt_chunk [231/1893] bpb=1.131180 time=1705.2s + ttt_chunk [241/1893] bpb=1.131385 time=1778.7s + ttt_chunk [251/1893] bpb=1.130903 time=1852.4s + ttt_chunk [261/1893] bpb=1.130156 time=1926.2s + ttt_chunk [271/1893] bpb=1.129245 time=1999.9s + ttt_chunk [281/1893] bpb=1.130732 time=2073.6s + ttt_chunk [291/1893] bpb=1.130322 time=2147.3s + ttt_chunk [301/1893] bpb=1.131145 time=2220.9s + ttt_chunk [311/1893] bpb=1.131184 time=2294.6s + ttt_chunk [321/1893] bpb=1.131877 time=2368.2s + ttt_chunk [331/1893] bpb=1.131323 time=2441.8s + ttt_chunk [341/1893] bpb=1.130879 time=2515.5s + ttt_chunk [351/1893] bpb=1.131565 time=2589.0s + ttt_chunk [361/1893] bpb=1.132309 time=2662.6s + ttt_chunk [371/1893] bpb=1.132205 time=2736.3s + ttt_chunk [381/1893] bpb=1.131957 time=2810.0s + ttt_chunk [391/1893] bpb=1.132640 time=2885.8s + ttt_chunk [401/1893] bpb=1.132159 time=2959.5s + ttt_chunk [411/1893] bpb=1.131170 time=3033.2s + ttt_chunk [421/1893] bpb=1.131239 time=3106.9s + ttt_chunk [431/1893] bpb=1.131658 time=3180.5s + ttt_chunk [441/1893] bpb=1.131042 time=3254.2s + ttt_chunk [451/1893] bpb=1.131166 time=3327.8s + ttt_chunk [461/1893] bpb=1.131031 time=3401.5s + ttt_chunk [471/1893] bpb=1.130577 time=3475.1s + ttt_chunk [481/1893] bpb=1.130385 time=3548.8s + ttt_chunk [491/1893] bpb=1.130544 time=3622.4s + 
ttt_chunk [501/1893] bpb=1.130284 time=3696.0s + ttt_chunk [511/1893] bpb=1.129779 time=3769.6s + ttt_chunk [521/1893] bpb=1.129314 time=3843.1s + ttt_chunk [531/1893] bpb=1.130018 time=3916.7s + ttt_chunk [541/1893] bpb=1.130099 time=3990.2s + ttt_chunk [551/1893] bpb=1.129554 time=4063.6s + ttt_chunk [561/1893] bpb=1.129400 time=4136.9s + ttt_chunk [571/1893] bpb=1.129109 time=4210.3s + ttt_chunk [581/1893] bpb=1.128706 time=4283.6s + ttt_chunk [591/1893] bpb=1.128130 time=4357.0s + ttt_chunk [601/1893] bpb=1.128118 time=4430.3s + ttt_chunk [611/1893] bpb=1.127777 time=4503.6s + ttt_chunk [621/1893] bpb=1.127617 time=4577.0s + ttt_chunk [631/1893] bpb=1.127359 time=4650.2s + ttt_chunk [641/1893] bpb=1.126912 time=4723.6s + ttt_chunk [651/1893] bpb=1.126447 time=4796.9s + ttt_chunk [661/1893] bpb=1.126330 time=4870.3s + ttt_chunk [671/1893] bpb=1.125836 time=4943.6s + ttt_chunk [681/1893] bpb=1.125266 time=5016.9s + ttt_chunk [691/1893] bpb=1.125340 time=5090.3s + ttt_chunk [701/1893] bpb=1.124500 time=5163.6s + ttt_chunk [711/1893] bpb=1.124502 time=5236.9s + ttt_chunk [721/1893] bpb=1.124411 time=5310.2s + ttt_chunk [731/1893] bpb=1.124652 time=5383.5s + ttt_chunk [741/1893] bpb=1.124528 time=5456.8s + ttt_chunk [751/1893] bpb=1.124227 time=5530.2s + ttt_chunk [761/1893] bpb=1.124358 time=5603.5s + ttt_chunk [771/1893] bpb=1.124185 time=5676.8s + ttt_chunk [781/1893] bpb=1.124354 time=5750.1s + ttt_chunk [791/1893] bpb=1.124211 time=5823.5s + ttt_chunk [801/1893] bpb=1.124147 time=5896.8s + ttt_chunk [811/1893] bpb=1.124151 time=5970.1s + ttt_chunk [821/1893] bpb=1.124052 time=6043.4s + ttt_chunk [831/1893] bpb=1.123777 time=6116.7s + ttt_chunk [841/1893] bpb=1.123545 time=6190.1s + ttt_chunk [851/1893] bpb=1.123609 time=6263.4s + ttt_chunk [861/1893] bpb=1.123678 time=6336.7s + ttt_chunk [871/1893] bpb=1.123882 time=6410.0s + ttt_chunk [881/1893] bpb=1.123884 time=6483.0s + ttt_chunk [891/1893] bpb=1.123365 time=6555.8s + ttt_chunk [901/1893] bpb=1.123387 time=6628.5s + ttt_chunk [911/1893] bpb=1.123230 time=6701.3s + ttt_chunk [921/1893] bpb=1.123363 time=6774.0s + ttt_chunk [931/1893] bpb=1.123317 time=6846.8s + ttt_chunk [941/1893] bpb=1.123522 time=6919.5s + ttt_chunk [951/1893] bpb=1.123821 time=6992.3s + ttt_chunk [961/1893] bpb=1.124109 time=7065.0s + ttt_chunk [971/1893] bpb=1.124460 time=7137.7s + ttt_chunk [981/1893] bpb=1.124663 time=7210.3s + ttt_chunk [991/1893] bpb=1.124562 time=7283.1s + ttt_chunk [1001/1893] bpb=1.124875 time=7355.8s + ttt_chunk [1011/1893] bpb=1.125008 time=7430.4s + ttt_chunk [1021/1893] bpb=1.125291 time=7503.1s + ttt_chunk [1031/1893] bpb=1.125664 time=7575.8s + ttt_chunk [1041/1893] bpb=1.126175 time=7648.5s + ttt_chunk [1051/1893] bpb=1.126029 time=7721.1s + ttt_chunk [1061/1893] bpb=1.126122 time=7793.8s + ttt_chunk [1071/1893] bpb=1.126273 time=7866.4s + ttt_chunk [1081/1893] bpb=1.126313 time=7939.1s + ttt_chunk [1091/1893] bpb=1.126566 time=8011.8s + ttt_chunk [1101/1893] bpb=1.126699 time=8084.5s + ttt_chunk [1111/1893] bpb=1.126441 time=8157.1s + ttt_chunk [1121/1893] bpb=1.126208 time=8229.8s + ttt_chunk [1131/1893] bpb=1.126093 time=8302.4s + ttt_chunk [1141/1893] bpb=1.125849 time=8375.0s + ttt_chunk [1151/1893] bpb=1.125863 time=8447.6s + ttt_chunk [1161/1893] bpb=1.125649 time=8520.3s + ttt_chunk [1171/1893] bpb=1.125474 time=8592.9s + ttt_chunk [1181/1893] bpb=1.125243 time=8665.6s + ttt_chunk [1191/1893] bpb=1.125389 time=8738.2s + ttt_chunk [1201/1893] bpb=1.125588 time=8810.9s + ttt_chunk [1211/1893] bpb=1.125185 time=8883.6s + 
ttt_chunk [1221/1893] bpb=1.125510 time=8956.3s + ttt_chunk [1231/1893] bpb=1.125434 time=9029.0s + ttt_chunk [1241/1893] bpb=1.125130 time=9101.6s + ttt_chunk [1251/1893] bpb=1.124594 time=9174.3s + ttt_chunk [1261/1893] bpb=1.124331 time=9247.0s + ttt_chunk [1271/1893] bpb=1.124082 time=9319.6s diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_e2ettt_flow7k_legal_ttt_complete.log b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_e2ettt_flow7k_legal_ttt_complete.log new file mode 100644 index 0000000000..5722d09312 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_e2ettt_flow7k_legal_ttt_complete.log @@ -0,0 +1,2846 @@ +logs/eval_e2ettt_flow7k_legal_ttt_55398555.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:28319753 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +eval_only: loading /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/e2ettt_flow_55363689/models/final_model_pr940_e2ettt_flow_55363689.pt, skipping training +step:0/0 val_loss:1.9237 val_bpb:1.1393 train_time:60ms step_avg:59.90ms +peak memory allocated: 27459 MiB reserved: 28022 MiB +save_paths: pt=final_model_eval_e2ettt_flow7k_legal_ttt_55398555.pt ptz=final_model_eval_e2ettt_flow7k_legal_ttt_55398555.int6.ptz +Serialized model: 108880403 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16854717 bytes (limit 15884968) +quant_try int6 zstd-1: 16914613 bytes (limit 15884968) +quant_try int6 zstd-17: 16856988 bytes (limit 15884968) +quant_try int6 zstd-2: 16923144 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 16553148 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16617668 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16551154 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16675594 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 16244171 bytes (limit 15884968) +quant_try int5[2L] zstd-1: 16320340 bytes (limit 15884968) +quant_try int5[2L] zstd-17: 16261223 bytes (limit 15884968) +quant_try int5[2L] zstd-2: 16427030 bytes (limit 15884968) +quant_fallback: int5 layers=[4, 5, 6] +quant_try int5[3L] zstd-16: 16007263 bytes (limit 15884968) +quant_try int5[3L] zstd-1: 16024461 bytes (limit 15884968) +quant_try int5[3L] zstd-17: 15970049 bytes (limit 15884968) +quant_try int5[3L] zstd-2: 16181460 bytes (limit 15884968) +quant_fallback: int5 
layers=[4, 5, 6, 7] +quant_try int5[4L] zstd-16: 15643338 bytes (limit 15884968) +Serialized model quant+zstd-16: 15643338 bytes +Total submission size: 15758370 bytes +final_int6_roundtrip val_loss:1.9458 val_bpb:1.1524 eval_time:243753ms +final_int6_roundtrip_exact val_loss:1.94578102 val_bpb:1.15240112 +legal_ttt:start stride=64 optimizer=sgd lr=0.002 epochs=10 freeze_blocks=2 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=10 ttt_optimizer=sgd freeze_blocks=2 +ttt_sliding:params unfrozen=23588837 frozen=4730916 + ttt_chunk [1/1893] bpb=1.219061 time=7.4s + ttt_chunk [11/1893] bpb=1.121041 time=81.1s + ttt_chunk [21/1893] bpb=1.130171 time=154.6s + ttt_chunk [31/1893] bpb=1.135459 time=228.3s + ttt_chunk [41/1893] bpb=1.132009 time=301.9s + ttt_chunk [51/1893] bpb=1.133333 time=375.5s + ttt_chunk [61/1893] bpb=1.137245 time=449.1s + ttt_chunk [71/1893] bpb=1.135520 time=522.6s + ttt_chunk [81/1893] bpb=1.132070 time=596.2s + ttt_chunk [91/1893] bpb=1.130891 time=669.8s + ttt_chunk [101/1893] bpb=1.131886 time=743.4s + ttt_chunk [111/1893] bpb=1.132075 time=817.1s + ttt_chunk [121/1893] bpb=1.128446 time=890.7s + ttt_chunk [131/1893] bpb=1.127652 time=964.3s + ttt_chunk [141/1893] bpb=1.126530 time=1037.8s + ttt_chunk [151/1893] bpb=1.126665 time=1111.3s + ttt_chunk [161/1893] bpb=1.127536 time=1184.8s + ttt_chunk [171/1893] bpb=1.129614 time=1258.3s + ttt_chunk [181/1893] bpb=1.129614 time=1331.8s + ttt_chunk [191/1893] bpb=1.132011 time=1405.3s + ttt_chunk [201/1893] bpb=1.131538 time=1478.9s + ttt_chunk [211/1893] bpb=1.130578 time=1552.5s + ttt_chunk [221/1893] bpb=1.131476 time=1626.0s + ttt_chunk [231/1893] bpb=1.131182 time=1699.5s + ttt_chunk [241/1893] bpb=1.131387 time=1773.1s + ttt_chunk [251/1893] bpb=1.130904 time=1846.7s + ttt_chunk [261/1893] bpb=1.130158 time=1920.2s + ttt_chunk [271/1893] bpb=1.129247 time=1993.8s + ttt_chunk [281/1893] bpb=1.130734 time=2067.3s + ttt_chunk [291/1893] bpb=1.130323 time=2143.1s + ttt_chunk [301/1893] bpb=1.131146 time=2216.6s + ttt_chunk [311/1893] bpb=1.131186 time=2290.2s + ttt_chunk [321/1893] bpb=1.131881 time=2363.7s + ttt_chunk [331/1893] bpb=1.131325 time=2437.1s + ttt_chunk [341/1893] bpb=1.130881 time=2510.6s + ttt_chunk [351/1893] bpb=1.131568 time=2584.1s + ttt_chunk [361/1893] bpb=1.132313 time=2657.6s + ttt_chunk [371/1893] bpb=1.132206 time=2731.0s + ttt_chunk [381/1893] bpb=1.131959 time=2804.5s + ttt_chunk [391/1893] bpb=1.132643 time=2878.2s + ttt_chunk [401/1893] bpb=1.132161 time=2951.7s + ttt_chunk [411/1893] bpb=1.131172 time=3025.3s + ttt_chunk [421/1893] bpb=1.131241 time=3098.9s + ttt_chunk [431/1893] bpb=1.131660 time=3172.5s + ttt_chunk [441/1893] bpb=1.131044 time=3246.1s + ttt_chunk [451/1893] bpb=1.131169 time=3319.7s + ttt_chunk [461/1893] bpb=1.131035 time=3393.2s + ttt_chunk [471/1893] bpb=1.130581 time=3466.8s + ttt_chunk [481/1893] bpb=1.130389 time=3540.4s + ttt_chunk [491/1893] bpb=1.130548 time=3613.9s + ttt_chunk [501/1893] bpb=1.130289 time=3687.5s + ttt_chunk [511/1893] bpb=1.129785 time=3763.2s + ttt_chunk [521/1893] bpb=1.129319 time=3836.9s + ttt_chunk [531/1893] bpb=1.130023 time=3910.6s + ttt_chunk [541/1893] bpb=1.130105 time=3984.2s + ttt_chunk [551/1893] bpb=1.129558 time=4057.7s + ttt_chunk [561/1893] bpb=1.129405 time=4131.3s + ttt_chunk [571/1893] bpb=1.129114 time=4204.9s + ttt_chunk [581/1893] bpb=1.128711 time=4278.7s + ttt_chunk [591/1893] bpb=1.128134 time=4352.3s + ttt_chunk [601/1893] bpb=1.128124 time=4425.9s + ttt_chunk 
[611/1893] bpb=1.127783 time=4499.5s + ttt_chunk [621/1893] bpb=1.127623 time=4573.2s + ttt_chunk [631/1893] bpb=1.127364 time=4646.7s + ttt_chunk [641/1893] bpb=1.126916 time=4720.3s + ttt_chunk [651/1893] bpb=1.126452 time=4793.9s + ttt_chunk [661/1893] bpb=1.126336 time=4867.5s + ttt_chunk [671/1893] bpb=1.125841 time=4941.1s + ttt_chunk [681/1893] bpb=1.125271 time=5014.6s + ttt_chunk [691/1893] bpb=1.125343 time=5088.3s + ttt_chunk [701/1893] bpb=1.124503 time=5161.9s + ttt_chunk [711/1893] bpb=1.124505 time=5235.4s + ttt_chunk [721/1893] bpb=1.124414 time=5309.1s + ttt_chunk [731/1893] bpb=1.124656 time=5382.8s + ttt_chunk [741/1893] bpb=1.124531 time=5456.4s + ttt_chunk [751/1893] bpb=1.124232 time=5530.0s + ttt_chunk [761/1893] bpb=1.124362 time=5603.6s + ttt_chunk [771/1893] bpb=1.124191 time=5677.4s + ttt_chunk [781/1893] bpb=1.124358 time=5751.1s + ttt_chunk [791/1893] bpb=1.124216 time=5824.7s + ttt_chunk [801/1893] bpb=1.124150 time=5898.2s + ttt_chunk [811/1893] bpb=1.124155 time=5971.6s + ttt_chunk [821/1893] bpb=1.124055 time=6045.2s + ttt_chunk [831/1893] bpb=1.123781 time=6118.7s + ttt_chunk [841/1893] bpb=1.123550 time=6192.2s + ttt_chunk [851/1893] bpb=1.123615 time=6265.7s + ttt_chunk [861/1893] bpb=1.123683 time=6339.1s + ttt_chunk [871/1893] bpb=1.123887 time=6412.7s + ttt_chunk [881/1893] bpb=1.123889 time=6486.1s + ttt_chunk [891/1893] bpb=1.123371 time=6559.7s + ttt_chunk [901/1893] bpb=1.123393 time=6633.2s + ttt_chunk [911/1893] bpb=1.123236 time=6706.7s + ttt_chunk [921/1893] bpb=1.123369 time=6780.2s + ttt_chunk [931/1893] bpb=1.123323 time=6853.7s + ttt_chunk [941/1893] bpb=1.123528 time=6927.2s + ttt_chunk [951/1893] bpb=1.123827 time=7000.7s + ttt_chunk [961/1893] bpb=1.124116 time=7074.1s + ttt_chunk [971/1893] bpb=1.124467 time=7147.7s + ttt_chunk [981/1893] bpb=1.124670 time=7221.3s + ttt_chunk [991/1893] bpb=1.124569 time=7294.8s + ttt_chunk [1001/1893] bpb=1.124882 time=7368.2s + ttt_chunk [1011/1893] bpb=1.125013 time=7441.7s + ttt_chunk [1021/1893] bpb=1.125297 time=7515.1s + ttt_chunk [1031/1893] bpb=1.125670 time=7588.6s + ttt_chunk [1041/1893] bpb=1.126180 time=7662.1s + ttt_chunk [1051/1893] bpb=1.126036 time=7735.5s + ttt_chunk [1061/1893] bpb=1.126128 time=7808.9s + ttt_chunk [1071/1893] bpb=1.126279 time=7882.4s + ttt_chunk [1081/1893] bpb=1.126319 time=7955.8s + ttt_chunk [1091/1893] bpb=1.126572 time=8029.5s + ttt_chunk [1101/1893] bpb=1.126705 time=8103.0s + ttt_chunk [1111/1893] bpb=1.126448 time=8177.0s + ttt_chunk [1121/1893] bpb=1.126214 time=8252.8s + ttt_chunk [1131/1893] bpb=1.126099 time=8326.2s + ttt_chunk [1141/1893] bpb=1.125855 time=8399.7s + ttt_chunk [1151/1893] bpb=1.125869 time=8473.1s + ttt_chunk [1161/1893] bpb=1.125655 time=8546.7s + ttt_chunk [1171/1893] bpb=1.125480 time=8620.5s + ttt_chunk [1181/1893] bpb=1.125249 time=8694.0s + ttt_chunk [1191/1893] bpb=1.125396 time=8767.6s + ttt_chunk [1201/1893] bpb=1.125594 time=8841.1s + ttt_chunk [1211/1893] bpb=1.125191 time=8914.7s + ttt_chunk [1221/1893] bpb=1.125516 time=8988.3s + ttt_chunk [1231/1893] bpb=1.125440 time=9061.8s + ttt_chunk [1241/1893] bpb=1.125135 time=9135.4s + ttt_chunk [1251/1893] bpb=1.124600 time=9209.0s + ttt_chunk [1261/1893] bpb=1.124336 time=9282.7s + ttt_chunk [1271/1893] bpb=1.124088 time=9356.4s + ttt_chunk [1281/1893] bpb=1.123772 time=9430.0s + ttt_chunk [1291/1893] bpb=1.123521 time=9503.6s + ttt_chunk [1301/1893] bpb=1.123471 time=9577.1s + ttt_chunk [1311/1893] bpb=1.123188 time=9650.7s + ttt_chunk [1321/1893] bpb=1.122893 time=9724.3s + 
ttt_chunk [1331/1893] bpb=1.122653 time=9797.9s + ttt_chunk [1341/1893] bpb=1.122520 time=9871.5s + ttt_chunk [1351/1893] bpb=1.122369 time=9945.1s + ttt_chunk [1361/1893] bpb=1.122500 time=10018.8s + ttt_chunk [1371/1893] bpb=1.122711 time=10092.4s + ttt_chunk [1381/1893] bpb=1.122921 time=10166.0s + ttt_chunk [1391/1893] bpb=1.122713 time=10239.6s + ttt_chunk [1401/1893] bpb=1.122753 time=10313.1s + ttt_chunk [1411/1893] bpb=1.122866 time=10386.7s + ttt_chunk [1421/1893] bpb=1.122861 time=10460.3s + ttt_chunk [1431/1893] bpb=1.122838 time=10533.8s + ttt_chunk [1441/1893] bpb=1.123318 time=10609.6s + ttt_chunk [1451/1893] bpb=1.123189 time=10683.1s + ttt_chunk [1461/1893] bpb=1.123121 time=10756.6s + ttt_chunk [1471/1893] bpb=1.123725 time=10830.2s + ttt_chunk [1481/1893] bpb=1.123601 time=10903.7s + ttt_chunk [1491/1893] bpb=1.123968 time=10977.2s + ttt_chunk [1501/1893] bpb=1.123946 time=11050.6s + ttt_chunk [1511/1893] bpb=1.123898 time=11124.1s + ttt_chunk [1521/1893] bpb=1.124014 time=11197.5s + ttt_chunk [1531/1893] bpb=1.124227 time=11271.0s + ttt_chunk [1541/1893] bpb=1.124295 time=11344.4s + ttt_chunk [1551/1893] bpb=1.124539 time=11418.0s + ttt_chunk [1561/1893] bpb=1.124622 time=11491.4s + ttt_chunk [1571/1893] bpb=1.124763 time=11564.7s + ttt_chunk [1581/1893] bpb=1.124918 time=11638.2s + ttt_chunk [1591/1893] bpb=1.124975 time=11711.5s + ttt_chunk [1601/1893] bpb=1.125091 time=11784.7s + ttt_chunk [1611/1893] bpb=1.125350 time=11857.8s + ttt_chunk [1621/1893] bpb=1.125216 time=11930.9s + ttt_chunk [1631/1893] bpb=1.125256 time=12004.1s + ttt_chunk [1641/1893] bpb=1.125275 time=12077.3s + ttt_chunk [1651/1893] bpb=1.125325 time=12152.4s + ttt_chunk [1661/1893] bpb=1.125470 time=12225.4s + ttt_chunk [1671/1893] bpb=1.125654 time=12298.4s + ttt_chunk [1681/1893] bpb=1.125744 time=12371.5s + ttt_chunk [1691/1893] bpb=1.125846 time=12444.4s + ttt_chunk [1701/1893] bpb=1.125941 time=12517.5s + ttt_chunk [1711/1893] bpb=1.125922 time=12590.5s + ttt_chunk [1721/1893] bpb=1.125758 time=12663.5s + ttt_chunk [1731/1893] bpb=1.125852 time=12736.4s + ttt_chunk [1741/1893] bpb=1.125590 time=12809.6s + ttt_chunk [1751/1893] bpb=1.125467 time=12882.6s + ttt_chunk [1761/1893] bpb=1.125505 time=12955.5s + ttt_chunk [1771/1893] bpb=1.125448 time=13028.4s + ttt_chunk [1781/1893] bpb=1.125347 time=13101.3s + ttt_chunk [1791/1893] bpb=1.125006 time=13174.2s + ttt_chunk [1801/1893] bpb=1.124982 time=13247.1s + ttt_chunk [1811/1893] bpb=1.124828 time=13320.0s + ttt_chunk [1821/1893] bpb=1.124885 time=13392.9s + ttt_chunk [1831/1893] bpb=1.124734 time=13465.8s + ttt_chunk [1841/1893] bpb=1.124742 time=13538.7s + ttt_chunk [1851/1893] bpb=1.124572 time=13611.7s + ttt_chunk [1861/1893] bpb=1.124487 time=13684.7s + ttt_chunk [1871/1893] bpb=1.124422 time=13757.7s + ttt_chunk [1881/1893] bpb=1.124172 time=13830.7s + ttt_chunk [1891/1893] bpb=1.124152 time=13903.7s + ttt_chunk [1893/1893] bpb=1.124183 time=13913.3s +ttt_sliding:done val_loss=1.898130 val_bpb=1.124183 elapsed=13913.3s +final_int6_sliding_window val_loss:1.8981 val_bpb:1.1242 stride:64 eval_time:13913796ms +final_int6_sliding_window_exact val_loss:1.89813008 val_bpb:1.12418252 + -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + 
f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,vr_lambda,attn_gate,canon_a,canon_c,delta_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, qmax: int = 127) -> tuple[Tensor, Tensor]: + """Quantize to [-qmax, qmax] range. Default int8 (qmax=127), int6 (qmax=31), int5 (qmax=15).""" + t32 = t.float() + qmin = -qmax + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
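+        # With model_dim=512 this passthrough covers the per-channel scales, gates and other
+        # control tensors (<= 65,536 elements each, i.e. at most ~128 KiB at fp16).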
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 for MLP layers 3-7 to save artifact space + int6_mlp_layers = os.environ.get("INT6_MLP_LAYERS", "") + qmax = 127 # default int8 + if int6_mlp_layers: + for li in int6_mlp_layers.split(","): + if li.strip() and f"blocks.{li.strip()}.mlp" in name and t.ndim == 2: + qmax = 31 # int6 + break + q, s = quantize_float_tensor(t, qmax=qmax) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round: bool = False + _soft_round_alpha: float = 1.0 + _quant_percentile: float = float(os.environ.get("QUANT_PERCENTILE", "1.0")) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + w32 = self.weight.float() + pct = CastedLinear._quant_percentile + row_max = (torch.quantile(w32.abs(), pct, dim=1) if pct < 1.0 + else w32.abs().amax(dim=1)).detach() + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + r = w32 / scale[:, None] + if CastedLinear._soft_round: + alpha = CastedLinear._soft_round_alpha + r_frac = r - r.detach().floor() - 0.5 + norm = torch.tanh(torch.tensor(alpha * 0.5, device=r.device, dtype=r.dtype)) + r_soft = r.detach().floor() + 0.5 + torch.tanh(alpha * r_frac) / (2.0 * norm) + w_q = (torch.clamp(r_soft, -32, 31) * scale[:, None]).to(x.dtype) + w = w_q # soft-round is differentiable, no STE needed + else: + with torch.no_grad(): + w_q = (torch.clamp(torch.round(r), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. 
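+    # e.g. with rope_dims=16 and base=10000, running at seq_len=2048 against train_seq_len=1024
+    # rescales the base to 10000 * 2**(16/14) ~= 2.2e4 before rebuilding the cos/sin caches.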
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + value_residual: bool = False, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = 
v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + fa_dtype = torch.bfloat16 + if _USE_FA3: + y = flash_attn_3_func(q.to(fa_dtype), k.to(fa_dtype), v.to(fa_dtype), causal=True) + else: + # SDPA fallback: (B, T, H, D) -> (B, H, T, D), expand KV for GQA + q_t = q.to(fa_dtype).transpose(1, 2) + k_t = k.to(fa_dtype).transpose(1, 2) + v_t = v.to(fa_dtype).transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(rep, dim=1) + v_t = v_t.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True) + y = y.transpose(1, 2) # (B, H, T, D) -> (B, T, H, D) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)) # (B, T, num_heads) + y = y * gate.unsqueeze(-1) # (B, T, H, 1) broadcast to (B, T, H, D) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.use_leaky = bool(int(os.environ.get("LEAKY_RELU", "1"))) + self.leaky_slope = 
float(os.environ.get("LEAKY_SLOPE", "0.5")) + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), self.leaky_slope) if self.use_leaky else torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class CanonAC(nn.Module): + """Canon Autoregressive Convolution with DeltaGate. Manual shift+mul (no Conv1d).""" + def __init__(self, dim: int, kernel: int = 4, delta_gate_init: float = -4.0): + super().__init__() + self.kernel = kernel + self.weight = nn.Parameter(torch.zeros(kernel, dim)) + self.delta_gate_logit = nn.Parameter(torch.tensor(delta_gate_init)) + + def forward(self, x: Tensor) -> Tensor: + B, T, D = x.shape + K = self.kernel + w = self.weight.to(x.dtype) + x_pad = F.pad(x, (0, 0, K - 1, 0)) + y = w[0] * x_pad[:, K - 1:, :] + for k in range(1, K): + y = y + w[k] * x_pad[:, K - 1 - k : T + K - 1 - k, :] + gate = torch.sigmoid(self.delta_gate_logit.to(x.dtype)) + return x + gate * y + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_kernel: int = 0, + canon_delta_gate_init: float = -4.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, value_residual=value_residual, + gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.canon_a = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + self.canon_c = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_in = self.attn_norm(x) * s + if self.canon_a is not None: + attn_in = self.canon_a(attn_in) + attn_out, raw_v = self.attn(attn_in, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s + if self.canon_c is not None: + mlp_in = self.canon_c(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + return x, raw_v + + +class FlowRefiner(nn.Module): + """1-step flow matching refiner in low-dim latent space. + + Projects hidden states to a low-dim latent, applies a learned velocity + field (1-step Euler), projects back. Output is additive adjustment to + hidden states. Initialized near-zero so the refiner starts as identity. 
+ """ + def __init__(self, model_dim: int, latent_dim: int = 64, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.down_proj = nn.Linear(model_dim, latent_dim, bias=False) + self.velocity_net = nn.Sequential( + nn.Linear(latent_dim, hidden_dim, bias=True), + nn.GELU(), + nn.Linear(hidden_dim, latent_dim, bias=True), + ) + self.up_proj = nn.Linear(latent_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.velocity_net[2]._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.down_proj.weight, std=self._init_scale) + nn.init.normal_(self.velocity_net[0].weight, std=self._init_scale) + nn.init.zeros_(self.velocity_net[0].bias) + nn.init.zeros_(self.velocity_net[2].weight) + nn.init.zeros_(self.velocity_net[2].bias) + nn.init.normal_(self.up_proj.weight, std=self._init_scale) + + def forward(self, x: Tensor) -> Tensor: + z = self.down_proj(x) + v = self.velocity_net(z) + z_refined = z + v + delta = self.up_proj(z_refined) + return torch.sigmoid(self.gate) * delta + + +class NativeFlowMatcher(nn.Module): + """Conditional Flow Matching refiner for hidden states. + + During training: computes auxiliary CFM loss on interpolated states. + During inference: applies a single Euler step correction at t=1. + """ + def __init__(self, model_dim: int, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.model_dim = model_dim + # Sinusoidal time embedding → projected to hidden_dim + self.time_proj = nn.Sequential( + nn.Linear(model_dim, hidden_dim, bias=True), + nn.GELU(), + ) + # Velocity network: x → hidden → x (with time conditioning via addition) + self.v_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.v_act = nn.GELU() + self.v_out = nn.Linear(hidden_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.v_out._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.v_in.weight, std=self._init_scale) + nn.init.zeros_(self.v_in.bias) + nn.init.normal_(self.time_proj[0].weight, std=self._init_scale) + nn.init.zeros_(self.time_proj[0].bias) + nn.init.zeros_(self.v_out.weight) + + @staticmethod + def _sinusoidal_time_emb(t: Tensor, dim: int) -> Tensor: + """Sinusoidal positional embedding for scalar time t. t: (...,) -> (..., dim).""" + half_dim = dim // 2 + emb = math.log(10000.0) / max(half_dim - 1, 1) + emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb) + emb = t.unsqueeze(-1) * emb # (..., half_dim) + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) # (..., dim) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + def _velocity(self, x: Tensor, t_emb: Tensor) -> Tensor: + """Compute velocity v(x, t). x: (..., model_dim), t_emb: (..., hidden_dim).""" + h = self.v_in(x) # (..., hidden_dim) + h = h + t_emb # time conditioning via addition + h = self.v_act(h) # GELU + return self.v_out(h) # (..., model_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (correction, cfm_loss). + + During training: cfm_loss is the flow matching objective (> 0). + During eval: cfm_loss is zero. + correction: gated velocity at t=1, to be added to x. 
+ """ + B_seq = x.shape[:-1] # works for any leading dimensions + D = self.model_dim + + # Inference path: velocity at t=1 (clean input, no noise) + t_ones = torch.ones(*B_seq, device=x.device, dtype=x.dtype) + t_emb_ones = self.time_proj(self._sinusoidal_time_emb(t_ones, D)) + v_at_1 = self._velocity(x, t_emb_ones) + correction = torch.sigmoid(self.gate) * v_at_1 + + # Training path: auxiliary CFM loss + if self.training: + # Sample random t ~ U(0,1) per position + t = torch.rand(*B_seq, device=x.device, dtype=x.dtype) + z = torch.randn_like(x) # noise + # OT interpolant: x_t = (1-t)*z + t*x + t_expanded = t.unsqueeze(-1) # (..., 1) + x_t = (1.0 - t_expanded) * z + t_expanded * x.detach() + # Target velocity = x - z (OT conditional velocity field) + target_v = x.detach() - z + # Predict velocity at interpolated point + t_emb = self.time_proj(self._sinusoidal_time_emb(t, D)) + pred_v = self._velocity(x_t, t_emb) + cfm_loss = F.mse_loss(pred_v, target_v) + else: + cfm_loss = x.new_zeros(()) + + return correction, cfm_loss + + +class TTTLinearRefiner(nn.Module): + """E2E TTT-Linear refiner (Sun et al., 2024, arXiv:2407.04620). + + Applied to final hidden states [B, L, D] before lm_head. The hidden + state is a per-head linear model W updated via mini-batch gradient + descent on a learned self-supervised reconstruction task. Uses the + dual form for GPU efficiency. Trained end-to-end: the outer loop + optimizes theta_K, theta_V, theta_Q projections and the inner-loop + initialization W_0, making the model learn WHAT to compress. + """ + + def __init__(self, model_dim: int, num_heads: int = 8, + mini_batch_size: int = 16, ttt_base_lr: float = 1.0): + super().__init__() + self.model_dim = model_dim + self.num_heads = num_heads + self.head_dim = model_dim // num_heads + self.mini_batch_size = mini_batch_size + self.ttt_base_lr = ttt_base_lr + self.eta = ttt_base_lr / self.head_dim + + # Learnable projections (outer-loop) + self.theta_K = nn.Linear(model_dim, model_dim, bias=False) + self.theta_V = nn.Linear(model_dim, model_dim, bias=False) + self.theta_Q = nn.Linear(model_dim, model_dim, bias=False) + + # Inner-loop model: W_0, b_0 + self.W1 = nn.Parameter(torch.normal(0, 0.02, + size=(num_heads, self.head_dim, self.head_dim))) + self.b1 = nn.Parameter(torch.zeros(num_heads, 1, self.head_dim)) + + # Inner-loop LayerNorm (outer-loop trainable) + self.ttt_ln_w = nn.Parameter(torch.ones(num_heads, self.head_dim)) + self.ttt_ln_b = nn.Parameter(torch.zeros(num_heads, self.head_dim)) + + # Output projection + gate (starts near zero for residual safety) + self.o_proj = nn.Linear(model_dim, model_dim, bias=False) + self.post_norm = nn.LayerNorm(model_dim, eps=1e-6) + self.gate = nn.Parameter(torch.tensor(-5.0)) + nn.init.zeros_(self.o_proj.weight) + self.o_proj._zero_init = True + + @staticmethod + def _ln_fwd(x: Tensor, gamma: Tensor, beta: Tensor, eps: float = 1e-6) -> Tensor: + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x_hat = (x - mu) / torch.sqrt(var + eps) + return gamma * x_hat + beta + + @staticmethod + def _ln_fused_l2_bwd(x: Tensor, target: Tensor, gamma: Tensor, beta: Tensor, + eps: float = 1e-6) -> Tensor: + D = x.shape[-1] + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + std = torch.sqrt(var + eps) + x_hat = (x - mu) / std + y = gamma * x_hat + beta + grad_output = y - target + grad_x_hat = grad_output * gamma + z = (1.0 / D) * ( + D * grad_x_hat + - grad_x_hat.sum(dim=-1, keepdim=True) + - x_hat * 
(grad_x_hat * x_hat).sum(dim=-1, keepdim=True) + ) / std + return z + + def forward(self, x: Tensor) -> Tensor: + """x: [B, L, D]. Returns [B, L, D] additive refinement (gated).""" + B, L, D = x.shape + NH, HD, b = self.num_heads, self.head_dim, self.mini_batch_size + eta = self.eta + + num_mb = L // b + usable_L = num_mb * b + x_proc = x[:, :usable_L] if usable_L < L else x + + # Project: [B, L, D] -> [B, NH, L, HD] + XK = self.theta_K(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XV = self.theta_V(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XQ = self.theta_Q(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + + # Reshape into mini-batches: [B, NH, num_mb, b, HD] + XK = XK.reshape(B, NH, num_mb, b, HD) + XV = XV.reshape(B, NH, num_mb, b, HD) + XQ = XQ.reshape(B, NH, num_mb, b, HD) + + ln_w = self.ttt_ln_w.reshape(1, NH, 1, HD) + ln_b = self.ttt_ln_b.reshape(1, NH, 1, HD) + + # Causal mask for dual form + causal = torch.tril(torch.ones(b, b, device=x.device, dtype=x.dtype)) + + # Initialize inner-loop model + W = self.W1.unsqueeze(0).expand(B, -1, -1, -1).clone() + bias = self.b1.unsqueeze(0).expand(B, -1, -1, -1).clone() + + all_out = [] + for m in range(num_mb): + xk = XK[:, :, m] + xv = XV[:, :, m] + xq = XQ[:, :, m] + + Z1 = xk @ W + bias + target = xv - xk + + grad = self._ln_fused_l2_bwd(Z1, target, ln_w, ln_b) + + Attn = causal.unsqueeze(0).unsqueeze(0) * (xq @ xk.transpose(-2, -1)) + cumgrad = (causal.unsqueeze(0).unsqueeze(0) @ grad) + b_bar = bias - eta * cumgrad + Z_bar = xq @ W - eta * Attn @ grad + b_bar + + W = W - eta * xk.transpose(-2, -1) @ grad + bias = bias - eta * grad.sum(dim=-2, keepdim=True) + + Z_bar = self._ln_fwd(Z_bar, ln_w, ln_b) + out_mb = xq + Z_bar + all_out.append(out_mb) + + z = torch.cat(all_out, dim=2).permute(0, 2, 1, 3).reshape(B, usable_L, D) + z = self.post_norm(z) + z = self.o_proj(z) + result = torch.sigmoid(self.gate) * z + + if usable_L < L: + pad = torch.zeros(B, L - usable_L, D, device=x.device, dtype=x.dtype) + result = torch.cat([result, pad], dim=1) + + return result + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_last_n: int = 0, + canon_kernel: int = 4, + canon_delta_gate_init: float = -4.0, + flow_enabled: bool = False, + flow_latent_dim: int = 64, + flow_hidden_dim: int = 256, + flow_init_scale: float = 0.01, + native_flow_enabled: bool = False, + native_flow_hidden_dim: int = 256, + native_flow_init_scale: float = 0.01, + native_flow_loss_weight: float = 0.1, + e2e_ttt_enabled: bool = False, + e2e_ttt_num_heads: int = 8, + e2e_ttt_mini_batch: int = 16, + e2e_ttt_base_lr: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, 
bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + canon_start = num_layers - canon_last_n if canon_last_n > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + value_residual=value_residual, + gated_attention=gated_attention, + canon_kernel=canon_kernel if i >= canon_start else 0, + canon_delta_gate_init=canon_delta_gate_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.flow_refiner = FlowRefiner(model_dim, flow_latent_dim, flow_hidden_dim, flow_init_scale) if flow_enabled else None + self.native_flow = NativeFlowMatcher(model_dim, native_flow_hidden_dim, native_flow_init_scale) if native_flow_enabled else None + self.native_flow_loss_weight = native_flow_loss_weight + self.ttt_refiner = TTTLinearRefiner(model_dim, e2e_ttt_num_heads, e2e_ttt_mini_batch, e2e_ttt_base_lr) if e2e_ttt_enabled else None + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + nflow_loss = x.new_zeros(()) + if self.native_flow is not None: + correction, nflow_loss = self.native_flow(x) + x = x + correction + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + if self.native_flow is not None and self.training: + main_loss = main_loss + self.native_flow_loss_weight * nflow_loss + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + if self.native_flow is not None: + correction, _ = self.native_flow(x) + x = x + correction + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + 
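+
+# Illustrative sketch (not part of the record's compute path; the helper name is
+# ours, not the pipeline's): how the sliding-window evaluation below maps onto
+# per-token scoring. The first window scores all of its positions; every later
+# window scores only its final `stride` positions, so each validation token is
+# scored exactly once with close-to-maximal left context. The reported metric is
+# then bits-per-byte: (mean token NLL / ln 2) * (token_count / byte_count).
+def _sliding_scored_segments(total_tokens: int, seq_len: int, stride: int):
+    """Yield (window_start, scored_start, scored_end) for each sliding window."""
+    for ws in range(0, total_tokens, stride):
+        end = min(ws + seq_len, total_tokens)
+        wlen = end - ws
+        if wlen < 1:
+            continue
+        # Positions before `s` were already scored by the previous window.
+        s = 0 if ws == 0 else max(wlen - stride, 0)
+        yield ws, ws + s, end
+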
+def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + Optionally uses entropy-gated 5-gram cache (NGRAM_CACHE=1).""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # N-gram eval cache with multi-order backoff + entropy-adaptive alpha (PR #702 inspired) + _ngram_default = "1" if world_size > 1 else "0" + use_ngram = bool(int(os.environ.get("NGRAM_CACHE", _ngram_default))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.40")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_order = int(os.environ.get("NGRAM_ORDER", "7")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = ngram_order - ngram_min_order + 1 + ctx_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + print(f"ngram_cache:enabled orders={ngram_min_order}-{ngram_order} backoff " + f"entropy={ngram_entropy} alpha={ngram_alpha} " + f"ent_base={ngram_ent_base} ent_range={ngram_ent_range} " + f"min_count={ngram_min_count} buckets={ngram_buckets}", flush=True) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + scored_nll = nll[i, s:wlen].to(torch.float64) + + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha: compute from model logits (GPU) + if ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ngram_ent_base + ngram_ent_range / ( + 1.0 + np.exp(-ngram_ent_scale * (seg_ent - ngram_ent_thresh))) + + # Precompute hashes for all orders + order_data = [] # (v_idx, ctx_key, full_key) per order + for oi in range(_n_orders): + ctx_w = ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, fill unmatched with lower orders + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if ngram_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll_np = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + scored_nll = torch.from_numpy(seg_nll_np).to(dtype=torch.float64, device=device) + + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + 
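+
+# Illustrative sketch (not part of the record's compute path; names are ours,
+# not the pipeline's): the "score-first" ordering that makes the TTT routines
+# below legal. Every chunk of validation data is fully scored with the current
+# weights BEFORE any gradient step is taken on that chunk, so no token's score
+# can depend on training that already saw it (or any later token). The legal
+# sliding-window variant additionally skips training on the final chunk.
+def _score_first_schedule(num_chunks: int) -> list[tuple[str, int]]:
+    """Return the (phase, chunk_index) order enforced by the legal TTT loop."""
+    schedule: list[tuple[str, int]] = []
+    for ci in range(num_chunks):
+        schedule.append(("score", ci))      # forward-only pass over chunk ci
+        if ci < num_chunks - 1:
+            schedule.append(("train", ci))  # adapt on chunk ci only after scoring it
+    return schedule
+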
+ +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _clear_rotary_caches(model: nn.Module) -> None: + """Clear cached RoPE tensors to avoid 'Inference tensors cannot be saved for backward'.""" + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + + +def ttt_adapt(args: Hyperparameters, base_model: nn.Module, device: torch.device, + val_tokens: Tensor, rank: int = 0, world_size: int = 1, + log_fn=None) -> None: + """Score-first TTT: process val data in chunks, score each chunk first + (inference_mode), then train on scored tokens. Compliant with Issue #677.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + batch_seqs = args.ttt_batch_seqs + + # Freeze early blocks + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + chunk_idx = 0 + + for chunk_start in range(0, total_tokens - seq_len, chunk_tokens): + chunk_end = min(chunk_start + chunk_tokens, total_tokens) + chunk_len = chunk_end - chunk_start + n_seqs = chunk_len // seq_len + if n_seqs == 0: + break + + my_start = (n_seqs * rank) // world_size + my_end = (n_seqs * (rank + 1)) // world_size + if my_end <= my_start: + continue + + # Phase 1: Score chunk under inference_mode (forward only) + base_model.eval() + with torch.inference_mode(): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + base_model.forward_logits(x) + + # Phase 2: Train on scored tokens (K epochs) + base_model.train() + for epoch in range(args.ttt_epochs): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + chunk_idx += 1 + if log_fn and chunk_idx % 20 == 0: + log_fn(f"ttt:chunk={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + # Restore all params + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done chunks={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, 
+ eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + """Legal single-pass TTT: score each chunk with sliding windows, then train on it. + Tokens are always scored BEFORE any training on their chunk, so the evaluation + is never contaminated by future information.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Build window starts (same logic as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Map each window to the chunk that contains its first scored token + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + if log_fn: + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + if log_fn: + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + _clear_rotary_caches(base_model) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR schedule with 5% warmup + warmup_chunks = max(num_chunks // 20, 1) + if ci < warmup_chunks: + lr_scale = (ci + 1) / warmup_chunks + else: + progress = (ci - warmup_chunks) / max(num_chunks - 1 - warmup_chunks, 1) + lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress)) + cos_lr = args.ttt_lr * lr_scale + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if log_fn and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all params and return to eval mode + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + if log_fn: + log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, qmax: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + qmin = -qmax - 1 + pct = CastedLinear._quant_percentile + if t32.ndim == 2: + row_max = (torch.quantile(t32.abs(), pct, dim=1) if pct < 1.0 + else t32.abs().amax(dim=1)) + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)).to(torch.float16) + clipped = t32.clamp(-row_max[:, None], row_max[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), qmin, qmax).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / float(qmax) if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), qmin, qmax).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_layers: set[int] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + if int5_layers is None: + int5_layers = set() + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Determine layer index for int5 fallback + layer_idx = -1 + if name.startswith("blocks."): + try: + layer_idx = int(name.split(".")[1]) + except (IndexError, ValueError): + pass + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + qmax = 15 if layer_idx in int5_layers else 31 + q, s = 
quantize_int6_per_row(t, qmax=qmax) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5" if qmax == 15 else "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + if _USE_FA3: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.flow_refiner is not None: + scalar_params.extend(list(base_model.flow_refiner.parameters())) + if base_model.native_flow is not None: + scalar_params.extend(list(base_model.native_flow.parameters())) + if base_model.ttt_refiner is not None: + scalar_params.extend(list(base_model.ttt_refiner.parameters())) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:fa3={_USE_FA3} cudnn=False flash=True mem_efficient={not _USE_FA3} math={not _USE_FA3}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 
1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + eval_only_path = os.environ.get("EVAL_ONLY", "") + if eval_only_path: + log0(f"eval_only: loading {eval_only_path}, skipping training") + base_model.load_state_dict(torch.load(eval_only_path, map_location=device, weights_only=False), strict=False) + ema_state = None # prevent random EMA from overwriting loaded weights + swa_state = None + swa_count = 0 + args.iterations = 0 # skip training, go straight to eval + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms 
step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if ( + args.checkpoint_every > 0 + and step > 0 + and step % args.checkpoint_every == 0 + and not last_step + and master_process + ): + ckpt_sd = {k: v for k, v in base_model.state_dict().items() if "mtp_heads" not in k} + ckpt_path = f"checkpoint_step{step}_{args.run_id}.pt" + torch.save(ckpt_sd, ckpt_path) + log0(f"checkpoint_saved: {ckpt_path} ({os.path.getsize(ckpt_path)} bytes)") + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._soft_round = args.soft_round_qat + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round:{args.soft_round_qat}") + if CastedLinear._qat_enabled and CastedLinear._soft_round: + qat_progress = max(0.0, 1.0 - (scale / qat_threshold)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress # 1→16 + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: + # Tight SWA: collect from EMA state if available, else from raw model + src = ema_state if ema_state is not None else {name: t.detach().float() for name, t in base_model.state_dict().items()} + if swa_state is None: + swa_state = {name: t.clone() for name, t in src.items()} + swa_count = 1 + log0(f"swa:start step:{step} tight={ema_state is not None}") + else: + for name in swa_state: + swa_state[name].add_(src[name]) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the 
wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying Tight SWA averaged {swa_count} EMA checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + if ema_state is not None: + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + # Use RUN_ID for unique filenames when multiple jobs share CWD + _run_id = os.environ.get("RUN_ID", "") + _save_prefix = f"final_model_{_run_id}" if _run_id else "final_model" + _pt_path = f"{_save_prefix}.pt" + _ptz_path = f"{_save_prefix}.int6.ptz" + log0(f"save_paths: pt={_pt_path} ptz={_ptz_path}") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, _pt_path) + model_bytes = os.path.getsize(_pt_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + code_bytes = len(code.encode("utf-8")) + artifact_limit = 16_000_000 - code_bytes + + # --- Auto-downgrade quantization: try int6 first, fall back to int5 middle layers --- + num_layers_total = max( + (int(k.split(".")[1]) for k in sd_cpu if k.startswith("blocks.")), + default=0, + ) + 1 + _zstd_levels = [int(os.environ.get("ZSTD_LEVEL", "16")), 1, 17, 2] + # Phase 1: pure int6 with multiple zstd levels + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = None + chosen_level = _zstd_levels[0] + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int6 zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + # Phase 2: progressive int5 fallback — one layer at a time from middle outward + if quant_blob is None: + mid = num_layers_total // 2 + # Expand outward from center: L5, L4, L6, L3, L7, L2, L8, ... 
+ candidates = [] + for offset in range(num_layers_total): + for sign in [0, 1]: + layer = mid + offset if sign == 0 else mid - offset + if 0 <= layer < num_layers_total and layer not in candidates: + candidates.append(layer) + int5_layers: set[int] = set() + for layer in candidates: + int5_layers.add(layer) + if master_process: + log0(f"quant_fallback: int5 layers={sorted(int5_layers)}") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, int5_layers=int5_layers) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int5[{len(int5_layers)}L] zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + if quant_blob is not None: + break + if quant_blob is None: + quant_blob = blob # Use last attempt even if over limit + if master_process: + log0(f"WARNING: artifact still over limit after all fallbacks") + if master_process: + with open(_ptz_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + log0(f"Serialized model quant+{_COMPRESSOR}-{chosen_level}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open(_ptz_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + legal_ttt = bool(int(os.environ.get("LEGAL_TTT", "0"))) + if args.ttt_enabled and not legal_ttt: + # --- 
Invalid two-pass TTT (adapt then eval separately) --- + if distributed: + dist.barrier() + for block in eval_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + log0(f"ttt:start score-first optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"chunk_tokens={args.ttt_chunk_tokens}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, + rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + if legal_ttt and args.ttt_enabled: + # Legal single-pass TTT: score → train interleaved per chunk + log0(f"legal_ttt:start stride={args.eval_stride} " + f"optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + sw_val_loss, sw_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + log_fn=log0, + ) + else: + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + + +==================================================================================================== +Running Python 3.12.7 | packaged by 
conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] +Running PyTorch 2.10.0+cu128 +Tue Mar 31 14:20:14 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.172.08 Driver Version: 570.172.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA A100-PCIE-40GB On | 00000000:E3:00.0 Off | 0 | +| N/A 31C P0 46W / 250W | 423MiB / 40960MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 251475 C ...ameter_golf/.venv/bin/python3 414MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:28319753 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +eval_only: loading /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/e2ettt_flow_55363689/models/final_model_pr940_e2ettt_flow_55363689.pt, skipping training +step:0/0 val_loss:1.9237 val_bpb:1.1393 train_time:60ms step_avg:59.90ms +peak memory allocated: 27459 MiB reserved: 28022 MiB +save_paths: pt=final_model_eval_e2ettt_flow7k_legal_ttt_55398555.pt ptz=final_model_eval_e2ettt_flow7k_legal_ttt_55398555.int6.ptz +Serialized model: 108880403 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16854717 bytes (limit 15884968) +quant_try int6 zstd-1: 16914613 bytes (limit 15884968) +quant_try int6 zstd-17: 16856988 bytes (limit 15884968) +quant_try int6 zstd-2: 16923144 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 16553148 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16617668 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16551154 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16675594 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] 
+quant_try int5[2L] zstd-16: 16244171 bytes (limit 15884968) +quant_try int5[2L] zstd-1: 16320340 bytes (limit 15884968) +quant_try int5[2L] zstd-17: 16261223 bytes (limit 15884968) +quant_try int5[2L] zstd-2: 16427030 bytes (limit 15884968) +quant_fallback: int5 layers=[4, 5, 6] +quant_try int5[3L] zstd-16: 16007263 bytes (limit 15884968) +quant_try int5[3L] zstd-1: 16024461 bytes (limit 15884968) +quant_try int5[3L] zstd-17: 15970049 bytes (limit 15884968) +quant_try int5[3L] zstd-2: 16181460 bytes (limit 15884968) +quant_fallback: int5 layers=[4, 5, 6, 7] +quant_try int5[4L] zstd-16: 15643338 bytes (limit 15884968) +Serialized model quant+zstd-16: 15643338 bytes +Total submission size: 15758370 bytes +final_int6_roundtrip val_loss:1.9458 val_bpb:1.1524 eval_time:243753ms +final_int6_roundtrip_exact val_loss:1.94578102 val_bpb:1.15240112 +legal_ttt:start stride=64 optimizer=sgd lr=0.002 epochs=10 freeze_blocks=2 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=10 ttt_optimizer=sgd freeze_blocks=2 +ttt_sliding:params unfrozen=23588837 frozen=4730916 + ttt_chunk [1/1893] bpb=1.219061 time=7.4s + ttt_chunk [11/1893] bpb=1.121041 time=81.1s + ttt_chunk [21/1893] bpb=1.130171 time=154.6s + ttt_chunk [31/1893] bpb=1.135459 time=228.3s + ttt_chunk [41/1893] bpb=1.132009 time=301.9s + ttt_chunk [51/1893] bpb=1.133333 time=375.5s + ttt_chunk [61/1893] bpb=1.137245 time=449.1s + ttt_chunk [71/1893] bpb=1.135520 time=522.6s + ttt_chunk [81/1893] bpb=1.132070 time=596.2s + ttt_chunk [91/1893] bpb=1.130891 time=669.8s + ttt_chunk [101/1893] bpb=1.131886 time=743.4s + ttt_chunk [111/1893] bpb=1.132075 time=817.1s + ttt_chunk [121/1893] bpb=1.128446 time=890.7s + ttt_chunk [131/1893] bpb=1.127652 time=964.3s + ttt_chunk [141/1893] bpb=1.126530 time=1037.8s + ttt_chunk [151/1893] bpb=1.126665 time=1111.3s + ttt_chunk [161/1893] bpb=1.127536 time=1184.8s + ttt_chunk [171/1893] bpb=1.129614 time=1258.3s + ttt_chunk [181/1893] bpb=1.129614 time=1331.8s + ttt_chunk [191/1893] bpb=1.132011 time=1405.3s + ttt_chunk [201/1893] bpb=1.131538 time=1478.9s + ttt_chunk [211/1893] bpb=1.130578 time=1552.5s + ttt_chunk [221/1893] bpb=1.131476 time=1626.0s + ttt_chunk [231/1893] bpb=1.131182 time=1699.5s + ttt_chunk [241/1893] bpb=1.131387 time=1773.1s + ttt_chunk [251/1893] bpb=1.130904 time=1846.7s + ttt_chunk [261/1893] bpb=1.130158 time=1920.2s + ttt_chunk [271/1893] bpb=1.129247 time=1993.8s + ttt_chunk [281/1893] bpb=1.130734 time=2067.3s + ttt_chunk [291/1893] bpb=1.130323 time=2143.1s + ttt_chunk [301/1893] bpb=1.131146 time=2216.6s + ttt_chunk [311/1893] bpb=1.131186 time=2290.2s + ttt_chunk [321/1893] bpb=1.131881 time=2363.7s + ttt_chunk [331/1893] bpb=1.131325 time=2437.1s + ttt_chunk [341/1893] bpb=1.130881 time=2510.6s + ttt_chunk [351/1893] bpb=1.131568 time=2584.1s + ttt_chunk [361/1893] bpb=1.132313 time=2657.6s + ttt_chunk [371/1893] bpb=1.132206 time=2731.0s + ttt_chunk [381/1893] bpb=1.131959 time=2804.5s + ttt_chunk [391/1893] bpb=1.132643 time=2878.2s + ttt_chunk [401/1893] bpb=1.132161 time=2951.7s + ttt_chunk [411/1893] bpb=1.131172 time=3025.3s + ttt_chunk [421/1893] bpb=1.131241 time=3098.9s + ttt_chunk [431/1893] bpb=1.131660 time=3172.5s + ttt_chunk [441/1893] bpb=1.131044 time=3246.1s + ttt_chunk [451/1893] bpb=1.131169 time=3319.7s + ttt_chunk [461/1893] bpb=1.131035 time=3393.2s + ttt_chunk [471/1893] bpb=1.130581 time=3466.8s + ttt_chunk [481/1893] bpb=1.130389 time=3540.4s + ttt_chunk [491/1893] bpb=1.130548 time=3613.9s + 
ttt_chunk [501/1893] bpb=1.130289 time=3687.5s + ttt_chunk [511/1893] bpb=1.129785 time=3763.2s + ttt_chunk [521/1893] bpb=1.129319 time=3836.9s + ttt_chunk [531/1893] bpb=1.130023 time=3910.6s + ttt_chunk [541/1893] bpb=1.130105 time=3984.2s + ttt_chunk [551/1893] bpb=1.129558 time=4057.7s + ttt_chunk [561/1893] bpb=1.129405 time=4131.3s + ttt_chunk [571/1893] bpb=1.129114 time=4204.9s + ttt_chunk [581/1893] bpb=1.128711 time=4278.7s + ttt_chunk [591/1893] bpb=1.128134 time=4352.3s + ttt_chunk [601/1893] bpb=1.128124 time=4425.9s + ttt_chunk [611/1893] bpb=1.127783 time=4499.5s + ttt_chunk [621/1893] bpb=1.127623 time=4573.2s + ttt_chunk [631/1893] bpb=1.127364 time=4646.7s + ttt_chunk [641/1893] bpb=1.126916 time=4720.3s + ttt_chunk [651/1893] bpb=1.126452 time=4793.9s + ttt_chunk [661/1893] bpb=1.126336 time=4867.5s + ttt_chunk [671/1893] bpb=1.125841 time=4941.1s + ttt_chunk [681/1893] bpb=1.125271 time=5014.6s + ttt_chunk [691/1893] bpb=1.125343 time=5088.3s + ttt_chunk [701/1893] bpb=1.124503 time=5161.9s + ttt_chunk [711/1893] bpb=1.124505 time=5235.4s + ttt_chunk [721/1893] bpb=1.124414 time=5309.1s + ttt_chunk [731/1893] bpb=1.124656 time=5382.8s + ttt_chunk [741/1893] bpb=1.124531 time=5456.4s + ttt_chunk [751/1893] bpb=1.124232 time=5530.0s + ttt_chunk [761/1893] bpb=1.124362 time=5603.6s + ttt_chunk [771/1893] bpb=1.124191 time=5677.4s + ttt_chunk [781/1893] bpb=1.124358 time=5751.1s + ttt_chunk [791/1893] bpb=1.124216 time=5824.7s + ttt_chunk [801/1893] bpb=1.124150 time=5898.2s + ttt_chunk [811/1893] bpb=1.124155 time=5971.6s + ttt_chunk [821/1893] bpb=1.124055 time=6045.2s + ttt_chunk [831/1893] bpb=1.123781 time=6118.7s + ttt_chunk [841/1893] bpb=1.123550 time=6192.2s + ttt_chunk [851/1893] bpb=1.123615 time=6265.7s + ttt_chunk [861/1893] bpb=1.123683 time=6339.1s + ttt_chunk [871/1893] bpb=1.123887 time=6412.7s + ttt_chunk [881/1893] bpb=1.123889 time=6486.1s + ttt_chunk [891/1893] bpb=1.123371 time=6559.7s + ttt_chunk [901/1893] bpb=1.123393 time=6633.2s + ttt_chunk [911/1893] bpb=1.123236 time=6706.7s + ttt_chunk [921/1893] bpb=1.123369 time=6780.2s + ttt_chunk [931/1893] bpb=1.123323 time=6853.7s + ttt_chunk [941/1893] bpb=1.123528 time=6927.2s + ttt_chunk [951/1893] bpb=1.123827 time=7000.7s + ttt_chunk [961/1893] bpb=1.124116 time=7074.1s + ttt_chunk [971/1893] bpb=1.124467 time=7147.7s + ttt_chunk [981/1893] bpb=1.124670 time=7221.3s + ttt_chunk [991/1893] bpb=1.124569 time=7294.8s + ttt_chunk [1001/1893] bpb=1.124882 time=7368.2s + ttt_chunk [1011/1893] bpb=1.125013 time=7441.7s + ttt_chunk [1021/1893] bpb=1.125297 time=7515.1s + ttt_chunk [1031/1893] bpb=1.125670 time=7588.6s + ttt_chunk [1041/1893] bpb=1.126180 time=7662.1s + ttt_chunk [1051/1893] bpb=1.126036 time=7735.5s + ttt_chunk [1061/1893] bpb=1.126128 time=7808.9s + ttt_chunk [1071/1893] bpb=1.126279 time=7882.4s + ttt_chunk [1081/1893] bpb=1.126319 time=7955.8s + ttt_chunk [1091/1893] bpb=1.126572 time=8029.5s + ttt_chunk [1101/1893] bpb=1.126705 time=8103.0s + ttt_chunk [1111/1893] bpb=1.126448 time=8177.0s + ttt_chunk [1121/1893] bpb=1.126214 time=8252.8s + ttt_chunk [1131/1893] bpb=1.126099 time=8326.2s + ttt_chunk [1141/1893] bpb=1.125855 time=8399.7s + ttt_chunk [1151/1893] bpb=1.125869 time=8473.1s + ttt_chunk [1161/1893] bpb=1.125655 time=8546.7s + ttt_chunk [1171/1893] bpb=1.125480 time=8620.5s + ttt_chunk [1181/1893] bpb=1.125249 time=8694.0s + ttt_chunk [1191/1893] bpb=1.125396 time=8767.6s + ttt_chunk [1201/1893] bpb=1.125594 time=8841.1s + ttt_chunk [1211/1893] bpb=1.125191 time=8914.7s + 
ttt_chunk [1221/1893] bpb=1.125516 time=8988.3s + ttt_chunk [1231/1893] bpb=1.125440 time=9061.8s + ttt_chunk [1241/1893] bpb=1.125135 time=9135.4s + ttt_chunk [1251/1893] bpb=1.124600 time=9209.0s + ttt_chunk [1261/1893] bpb=1.124336 time=9282.7s + ttt_chunk [1271/1893] bpb=1.124088 time=9356.4s + ttt_chunk [1281/1893] bpb=1.123772 time=9430.0s + ttt_chunk [1291/1893] bpb=1.123521 time=9503.6s + ttt_chunk [1301/1893] bpb=1.123471 time=9577.1s + ttt_chunk [1311/1893] bpb=1.123188 time=9650.7s + ttt_chunk [1321/1893] bpb=1.122893 time=9724.3s + ttt_chunk [1331/1893] bpb=1.122653 time=9797.9s + ttt_chunk [1341/1893] bpb=1.122520 time=9871.5s + ttt_chunk [1351/1893] bpb=1.122369 time=9945.1s + ttt_chunk [1361/1893] bpb=1.122500 time=10018.8s + ttt_chunk [1371/1893] bpb=1.122711 time=10092.4s + ttt_chunk [1381/1893] bpb=1.122921 time=10166.0s + ttt_chunk [1391/1893] bpb=1.122713 time=10239.6s + ttt_chunk [1401/1893] bpb=1.122753 time=10313.1s + ttt_chunk [1411/1893] bpb=1.122866 time=10386.7s + ttt_chunk [1421/1893] bpb=1.122861 time=10460.3s + ttt_chunk [1431/1893] bpb=1.122838 time=10533.8s + ttt_chunk [1441/1893] bpb=1.123318 time=10609.6s + ttt_chunk [1451/1893] bpb=1.123189 time=10683.1s + ttt_chunk [1461/1893] bpb=1.123121 time=10756.6s + ttt_chunk [1471/1893] bpb=1.123725 time=10830.2s + ttt_chunk [1481/1893] bpb=1.123601 time=10903.7s + ttt_chunk [1491/1893] bpb=1.123968 time=10977.2s + ttt_chunk [1501/1893] bpb=1.123946 time=11050.6s + ttt_chunk [1511/1893] bpb=1.123898 time=11124.1s + ttt_chunk [1521/1893] bpb=1.124014 time=11197.5s + ttt_chunk [1531/1893] bpb=1.124227 time=11271.0s + ttt_chunk [1541/1893] bpb=1.124295 time=11344.4s + ttt_chunk [1551/1893] bpb=1.124539 time=11418.0s + ttt_chunk [1561/1893] bpb=1.124622 time=11491.4s + ttt_chunk [1571/1893] bpb=1.124763 time=11564.7s + ttt_chunk [1581/1893] bpb=1.124918 time=11638.2s + ttt_chunk [1591/1893] bpb=1.124975 time=11711.5s + ttt_chunk [1601/1893] bpb=1.125091 time=11784.7s + ttt_chunk [1611/1893] bpb=1.125350 time=11857.8s + ttt_chunk [1621/1893] bpb=1.125216 time=11930.9s + ttt_chunk [1631/1893] bpb=1.125256 time=12004.1s + ttt_chunk [1641/1893] bpb=1.125275 time=12077.3s + ttt_chunk [1651/1893] bpb=1.125325 time=12152.4s + ttt_chunk [1661/1893] bpb=1.125470 time=12225.4s + ttt_chunk [1671/1893] bpb=1.125654 time=12298.4s + ttt_chunk [1681/1893] bpb=1.125744 time=12371.5s + ttt_chunk [1691/1893] bpb=1.125846 time=12444.4s + ttt_chunk [1701/1893] bpb=1.125941 time=12517.5s + ttt_chunk [1711/1893] bpb=1.125922 time=12590.5s + ttt_chunk [1721/1893] bpb=1.125758 time=12663.5s + ttt_chunk [1731/1893] bpb=1.125852 time=12736.4s + ttt_chunk [1741/1893] bpb=1.125590 time=12809.6s + ttt_chunk [1751/1893] bpb=1.125467 time=12882.6s + ttt_chunk [1761/1893] bpb=1.125505 time=12955.5s + ttt_chunk [1771/1893] bpb=1.125448 time=13028.4s + ttt_chunk [1781/1893] bpb=1.125347 time=13101.3s + ttt_chunk [1791/1893] bpb=1.125006 time=13174.2s + ttt_chunk [1801/1893] bpb=1.124982 time=13247.1s + ttt_chunk [1811/1893] bpb=1.124828 time=13320.0s + ttt_chunk [1821/1893] bpb=1.124885 time=13392.9s + ttt_chunk [1831/1893] bpb=1.124734 time=13465.8s + ttt_chunk [1841/1893] bpb=1.124742 time=13538.7s + ttt_chunk [1851/1893] bpb=1.124572 time=13611.7s + ttt_chunk [1861/1893] bpb=1.124487 time=13684.7s + ttt_chunk [1871/1893] bpb=1.124422 time=13757.7s + ttt_chunk [1881/1893] bpb=1.124172 time=13830.7s + ttt_chunk [1891/1893] bpb=1.124152 time=13903.7s + ttt_chunk [1893/1893] bpb=1.124183 time=13913.3s +ttt_sliding:done val_loss=1.898130 
val_bpb=1.124183 elapsed=13913.3s +final_int6_sliding_window val_loss:1.8981 val_bpb:1.1242 stride:64 eval_time:13913796ms +final_int6_sliding_window_exact val_loss:1.89813008 val_bpb:1.12418252 diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_nflow7k_legal_ttt.log b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_nflow7k_legal_ttt.log new file mode 100644 index 0000000000..99bd3203cf --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_nflow7k_legal_ttt.log @@ -0,0 +1,2847 @@ +logs/eval_nflow7k_legal_ttt_55375245.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +eval_only: loading /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_55342820/models/final_model_pr940_nflow_55342820.pt, skipping training +step:0/0 val_loss:1.9197 val_bpb:1.1370 train_time:54ms step_avg:54.40ms +peak memory allocated: 25725 MiB reserved: 25972 MiB +save_paths: pt=final_model_eval_nflow7k_legal_ttt_55375245.pt ptz=final_model_eval_nflow7k_legal_ttt_55375245.int6.ptz +Serialized model: 107295853 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16242245 bytes (limit 15884968) +quant_try int6 zstd-1: 16298729 bytes (limit 15884968) +quant_try int6 zstd-17: 16250339 bytes (limit 15884968) +quant_try int6 zstd-2: 16301974 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15927696 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16001920 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16297960 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16053283 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15630744 bytes (limit 15884968) +Serialized model quant+zstd-16: 15630744 bytes +Total submission size: 15745776 bytes +final_int6_roundtrip val_loss:1.9363 val_bpb:1.1468 eval_time:67588ms +final_int6_roundtrip_exact val_loss:1.93630746 val_bpb:1.14679034 +legal_ttt:start stride=64 optimizer=sgd lr=0.002 epochs=10 freeze_blocks=2 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=10 ttt_optimizer=sgd freeze_blocks=2 +ttt_sliding:params unfrozen=22800036 frozen=4730916 + ttt_chunk [1/1893] bpb=1.212107 time=3.8s + ttt_chunk [11/1893] bpb=1.115790 time=41.7s + ttt_chunk [21/1893] bpb=1.125357 time=79.6s + ttt_chunk [31/1893] bpb=1.130512 
time=117.6s + ttt_chunk [41/1893] bpb=1.127394 time=155.6s + ttt_chunk [51/1893] bpb=1.128903 time=193.6s + ttt_chunk [61/1893] bpb=1.132827 time=231.7s + ttt_chunk [71/1893] bpb=1.131231 time=269.7s + ttt_chunk [81/1893] bpb=1.127681 time=307.8s + ttt_chunk [91/1893] bpb=1.126507 time=345.8s + ttt_chunk [101/1893] bpb=1.127414 time=383.9s + ttt_chunk [111/1893] bpb=1.127682 time=421.9s + ttt_chunk [121/1893] bpb=1.124073 time=459.9s + ttt_chunk [131/1893] bpb=1.123226 time=498.0s + ttt_chunk [141/1893] bpb=1.122055 time=536.0s + ttt_chunk [151/1893] bpb=1.122195 time=574.0s + ttt_chunk [161/1893] bpb=1.123067 time=612.0s + ttt_chunk [171/1893] bpb=1.125089 time=649.9s + ttt_chunk [181/1893] bpb=1.125155 time=687.9s + ttt_chunk [191/1893] bpb=1.127505 time=725.9s + ttt_chunk [201/1893] bpb=1.126992 time=763.9s + ttt_chunk [211/1893] bpb=1.126053 time=801.8s + ttt_chunk [221/1893] bpb=1.126909 time=839.8s + ttt_chunk [231/1893] bpb=1.126593 time=877.8s + ttt_chunk [241/1893] bpb=1.126801 time=915.7s + ttt_chunk [251/1893] bpb=1.126322 time=953.7s + ttt_chunk [261/1893] bpb=1.125593 time=991.7s + ttt_chunk [271/1893] bpb=1.124720 time=1029.6s + ttt_chunk [281/1893] bpb=1.126223 time=1067.6s + ttt_chunk [291/1893] bpb=1.125781 time=1105.6s + ttt_chunk [301/1893] bpb=1.126625 time=1143.5s + ttt_chunk [311/1893] bpb=1.126647 time=1181.5s + ttt_chunk [321/1893] bpb=1.127307 time=1219.5s + ttt_chunk [331/1893] bpb=1.126728 time=1257.5s + ttt_chunk [341/1893] bpb=1.126294 time=1295.4s + ttt_chunk [351/1893] bpb=1.126974 time=1333.4s + ttt_chunk [361/1893] bpb=1.127742 time=1371.4s + ttt_chunk [371/1893] bpb=1.127643 time=1409.4s + ttt_chunk [381/1893] bpb=1.127394 time=1447.3s + ttt_chunk [391/1893] bpb=1.128075 time=1485.3s + ttt_chunk [401/1893] bpb=1.127618 time=1523.3s + ttt_chunk [411/1893] bpb=1.126636 time=1561.3s + ttt_chunk [421/1893] bpb=1.126693 time=1599.3s + ttt_chunk [431/1893] bpb=1.127118 time=1637.2s + ttt_chunk [441/1893] bpb=1.126511 time=1675.2s + ttt_chunk [451/1893] bpb=1.126643 time=1713.2s + ttt_chunk [461/1893] bpb=1.126522 time=1751.2s + ttt_chunk [471/1893] bpb=1.126096 time=1789.2s + ttt_chunk [481/1893] bpb=1.125905 time=1827.3s + ttt_chunk [491/1893] bpb=1.126069 time=1865.3s + ttt_chunk [501/1893] bpb=1.125822 time=1903.3s + ttt_chunk [511/1893] bpb=1.125318 time=1941.4s + ttt_chunk [521/1893] bpb=1.124852 time=1979.4s + ttt_chunk [531/1893] bpb=1.125548 time=2017.4s + ttt_chunk [541/1893] bpb=1.125645 time=2055.4s + ttt_chunk [551/1893] bpb=1.125098 time=2093.4s + ttt_chunk [561/1893] bpb=1.124953 time=2131.4s + ttt_chunk [571/1893] bpb=1.124686 time=2169.5s + ttt_chunk [581/1893] bpb=1.124306 time=2207.5s + ttt_chunk [591/1893] bpb=1.123730 time=2245.5s + ttt_chunk [601/1893] bpb=1.123722 time=2283.5s + ttt_chunk [611/1893] bpb=1.123387 time=2321.6s + ttt_chunk [621/1893] bpb=1.123232 time=2359.6s + ttt_chunk [631/1893] bpb=1.122966 time=2397.6s + ttt_chunk [641/1893] bpb=1.122515 time=2435.7s + ttt_chunk [651/1893] bpb=1.122053 time=2473.7s + ttt_chunk [661/1893] bpb=1.121945 time=2511.7s + ttt_chunk [671/1893] bpb=1.121464 time=2549.7s + ttt_chunk [681/1893] bpb=1.120886 time=2587.7s + ttt_chunk [691/1893] bpb=1.120968 time=2625.7s + ttt_chunk [701/1893] bpb=1.120141 time=2663.7s + ttt_chunk [711/1893] bpb=1.120150 time=2701.6s + ttt_chunk [721/1893] bpb=1.120057 time=2739.6s + ttt_chunk [731/1893] bpb=1.120288 time=2777.6s + ttt_chunk [741/1893] bpb=1.120158 time=2815.6s + ttt_chunk [751/1893] bpb=1.119858 time=2853.6s + ttt_chunk [761/1893] bpb=1.119991 
time=2891.6s + ttt_chunk [771/1893] bpb=1.119821 time=2929.6s + ttt_chunk [781/1893] bpb=1.119996 time=2967.5s + ttt_chunk [791/1893] bpb=1.119860 time=3005.5s + ttt_chunk [801/1893] bpb=1.119803 time=3043.5s + ttt_chunk [811/1893] bpb=1.119818 time=3081.5s + ttt_chunk [821/1893] bpb=1.119709 time=3119.5s + ttt_chunk [831/1893] bpb=1.119425 time=3157.5s + ttt_chunk [841/1893] bpb=1.119191 time=3195.5s + ttt_chunk [851/1893] bpb=1.119263 time=3233.5s + ttt_chunk [861/1893] bpb=1.119325 time=3271.6s + ttt_chunk [871/1893] bpb=1.119532 time=3309.6s + ttt_chunk [881/1893] bpb=1.119529 time=3347.6s + ttt_chunk [891/1893] bpb=1.119002 time=3385.6s + ttt_chunk [901/1893] bpb=1.119023 time=3423.6s + ttt_chunk [911/1893] bpb=1.118886 time=3461.6s + ttt_chunk [921/1893] bpb=1.119020 time=3499.6s + ttt_chunk [931/1893] bpb=1.118987 time=3537.6s + ttt_chunk [941/1893] bpb=1.119191 time=3575.6s + ttt_chunk [951/1893] bpb=1.119481 time=3613.6s + ttt_chunk [961/1893] bpb=1.119771 time=3651.5s + ttt_chunk [971/1893] bpb=1.120129 time=3689.5s + ttt_chunk [981/1893] bpb=1.120331 time=3727.5s + ttt_chunk [991/1893] bpb=1.120230 time=3765.5s + ttt_chunk [1001/1893] bpb=1.120542 time=3803.5s + ttt_chunk [1011/1893] bpb=1.120676 time=3841.5s + ttt_chunk [1021/1893] bpb=1.120959 time=3879.5s + ttt_chunk [1031/1893] bpb=1.121332 time=3917.4s + ttt_chunk [1041/1893] bpb=1.121840 time=3955.4s + ttt_chunk [1051/1893] bpb=1.121693 time=3993.4s + ttt_chunk [1061/1893] bpb=1.121785 time=4031.4s + ttt_chunk [1071/1893] bpb=1.121934 time=4069.4s + ttt_chunk [1081/1893] bpb=1.121980 time=4107.4s + ttt_chunk [1091/1893] bpb=1.122239 time=4145.3s + ttt_chunk [1101/1893] bpb=1.122380 time=4183.3s + ttt_chunk [1111/1893] bpb=1.122116 time=4221.3s + ttt_chunk [1121/1893] bpb=1.121879 time=4259.3s + ttt_chunk [1131/1893] bpb=1.121775 time=4297.3s + ttt_chunk [1141/1893] bpb=1.121530 time=4335.2s + ttt_chunk [1151/1893] bpb=1.121538 time=4373.2s + ttt_chunk [1161/1893] bpb=1.121313 time=4411.2s + ttt_chunk [1171/1893] bpb=1.121138 time=4449.2s + ttt_chunk [1181/1893] bpb=1.120908 time=4487.2s + ttt_chunk [1191/1893] bpb=1.121048 time=4525.1s + ttt_chunk [1201/1893] bpb=1.121252 time=4563.1s + ttt_chunk [1211/1893] bpb=1.120843 time=4601.1s + ttt_chunk [1221/1893] bpb=1.121177 time=4639.1s + ttt_chunk [1231/1893] bpb=1.121102 time=4677.1s + ttt_chunk [1241/1893] bpb=1.120801 time=4715.0s + ttt_chunk [1251/1893] bpb=1.120267 time=4753.1s + ttt_chunk [1261/1893] bpb=1.119999 time=4791.1s + ttt_chunk [1271/1893] bpb=1.119752 time=4829.1s + ttt_chunk [1281/1893] bpb=1.119439 time=4867.1s + ttt_chunk [1291/1893] bpb=1.119197 time=4905.1s + ttt_chunk [1301/1893] bpb=1.119156 time=4943.1s + ttt_chunk [1311/1893] bpb=1.118872 time=4981.1s + ttt_chunk [1321/1893] bpb=1.118578 time=5019.2s + ttt_chunk [1331/1893] bpb=1.118339 time=5057.2s + ttt_chunk [1341/1893] bpb=1.118212 time=5095.2s + ttt_chunk [1351/1893] bpb=1.118063 time=5133.2s + ttt_chunk [1361/1893] bpb=1.118196 time=5171.2s + ttt_chunk [1371/1893] bpb=1.118409 time=5209.2s + ttt_chunk [1381/1893] bpb=1.118608 time=5247.2s + ttt_chunk [1391/1893] bpb=1.118407 time=5285.2s + ttt_chunk [1401/1893] bpb=1.118446 time=5323.2s + ttt_chunk [1411/1893] bpb=1.118557 time=5361.2s + ttt_chunk [1421/1893] bpb=1.118548 time=5399.2s + ttt_chunk [1431/1893] bpb=1.118528 time=5437.2s + ttt_chunk [1441/1893] bpb=1.119007 time=5475.2s + ttt_chunk [1451/1893] bpb=1.118878 time=5513.1s + ttt_chunk [1461/1893] bpb=1.118811 time=5551.2s + ttt_chunk [1471/1893] bpb=1.119408 time=5589.1s + ttt_chunk 
[1481/1893] bpb=1.119277 time=5627.1s + ttt_chunk [1491/1893] bpb=1.119642 time=5665.1s + ttt_chunk [1501/1893] bpb=1.119625 time=5703.1s + ttt_chunk [1511/1893] bpb=1.119578 time=5741.1s + ttt_chunk [1521/1893] bpb=1.119692 time=5779.1s + ttt_chunk [1531/1893] bpb=1.119911 time=5817.1s + ttt_chunk [1541/1893] bpb=1.119986 time=5855.0s + ttt_chunk [1551/1893] bpb=1.120227 time=5893.0s + ttt_chunk [1561/1893] bpb=1.120314 time=5931.0s + ttt_chunk [1571/1893] bpb=1.120455 time=5969.0s + ttt_chunk [1581/1893] bpb=1.120609 time=6006.9s + ttt_chunk [1591/1893] bpb=1.120669 time=6044.9s + ttt_chunk [1601/1893] bpb=1.120794 time=6082.9s + ttt_chunk [1611/1893] bpb=1.121048 time=6120.9s + ttt_chunk [1621/1893] bpb=1.120917 time=6158.9s + ttt_chunk [1631/1893] bpb=1.120960 time=6196.8s + ttt_chunk [1641/1893] bpb=1.120983 time=6234.8s + ttt_chunk [1651/1893] bpb=1.121033 time=6272.8s + ttt_chunk [1661/1893] bpb=1.121175 time=6310.8s + ttt_chunk [1671/1893] bpb=1.121358 time=6348.8s + ttt_chunk [1681/1893] bpb=1.121447 time=6386.8s + ttt_chunk [1691/1893] bpb=1.121552 time=6424.8s + ttt_chunk [1701/1893] bpb=1.121657 time=6462.8s + ttt_chunk [1711/1893] bpb=1.121636 time=6500.8s + ttt_chunk [1721/1893] bpb=1.121475 time=6538.7s + ttt_chunk [1731/1893] bpb=1.121572 time=6576.7s + ttt_chunk [1741/1893] bpb=1.121311 time=6614.7s + ttt_chunk [1751/1893] bpb=1.121184 time=6652.7s + ttt_chunk [1761/1893] bpb=1.121225 time=6690.7s + ttt_chunk [1771/1893] bpb=1.121172 time=6728.6s + ttt_chunk [1781/1893] bpb=1.121072 time=6766.6s + ttt_chunk [1791/1893] bpb=1.120730 time=6804.6s + ttt_chunk [1801/1893] bpb=1.120711 time=6842.6s + ttt_chunk [1811/1893] bpb=1.120557 time=6880.5s + ttt_chunk [1821/1893] bpb=1.120615 time=6918.5s + ttt_chunk [1831/1893] bpb=1.120464 time=6956.5s + ttt_chunk [1841/1893] bpb=1.120475 time=6994.4s + ttt_chunk [1851/1893] bpb=1.120303 time=7032.4s + ttt_chunk [1861/1893] bpb=1.120214 time=7070.4s + ttt_chunk [1871/1893] bpb=1.120145 time=7108.4s + ttt_chunk [1881/1893] bpb=1.119902 time=7146.3s + ttt_chunk [1891/1893] bpb=1.119878 time=7184.3s + ttt_chunk [1893/1893] bpb=1.119907 time=7189.7s +ttt_sliding:done val_loss=1.890910 val_bpb=1.119907 elapsed=7189.8s +final_int6_sliding_window val_loss:1.8909 val_bpb:1.1199 stride:64 eval_time:7190260ms +final_int6_sliding_window_exact val_loss:1.89091021 val_bpb:1.11990650 +ttern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
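+    # Concatenate every matching shard, then trim to a whole number of seq_len windows;
+    # the "+1" kept below provides the extra token needed to build shifted (input, target) pairs.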
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
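+#
+# Illustrative sketch of the per-row round trip used below (hypothetical 2D fp32 weight `w`;
+# the real path additionally clips at a high percentile and stores the scales in fp16):
+#
+#   scale = (w.abs().amax(dim=1) / 127.0).clamp_min(1.0 / 127.0)   # one scale per row
+#   q = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)
+#   w_hat = q.float() * scale[:, None]                             # dequantized weights for eval
+#
+# Per-element reconstruction error is at most scale/2 for unclipped values; the small
+# control tensors are cheap enough to pass through in float instead of being quantized.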
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,vr_lambda,attn_gate,canon_a,canon_c,delta_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, qmax: int = 127) -> tuple[Tensor, Tensor]: + """Quantize to [-qmax, qmax] range. Default int8 (qmax=127), int6 (qmax=31), int5 (qmax=15).""" + t32 = t.float() + qmin = -qmax + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 for MLP layers 3-7 to save artifact space + int6_mlp_layers = os.environ.get("INT6_MLP_LAYERS", "") + qmax = 127 # default int8 + if int6_mlp_layers: + for li in int6_mlp_layers.split(","): + if li.strip() and f"blocks.{li.strip()}.mlp" in name and t.ndim == 2: + qmax = 31 # int6 + break + q, s = quantize_float_tensor(t, qmax=qmax) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
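+    # e.g. with this run's settings (train_batch_tokens=786432, world_size=1,
+    # grad_accum_steps=8, seq_len=2048), each next_batch() call yields
+    # 786432 / (1 * 8) = 98304 local tokens, i.e. 48 sequences of 2048 tokens.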
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round: bool = False + _soft_round_alpha: float = 1.0 + _quant_percentile: float = float(os.environ.get("QUANT_PERCENTILE", "1.0")) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + w32 = self.weight.float() + pct = CastedLinear._quant_percentile + row_max = (torch.quantile(w32.abs(), pct, dim=1) if pct < 1.0 + else w32.abs().amax(dim=1)).detach() + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + r = w32 / scale[:, None] + if CastedLinear._soft_round: + alpha = CastedLinear._soft_round_alpha + r_frac = r - r.detach().floor() - 0.5 + norm = torch.tanh(torch.tensor(alpha * 0.5, device=r.device, dtype=r.dtype)) + r_soft = r.detach().floor() + 0.5 + torch.tanh(alpha * r_frac) / (2.0 * norm) + w_q = (torch.clamp(r_soft, -32, 31) * scale[:, None]).to(x.dtype) + w = w_q # soft-round is differentiable, no STE needed + else: + with torch.no_grad(): + w_q = (torch.clamp(torch.round(r), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. 
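+    # e.g. with rope_dims=16, base=10000, train_seq_len=1024: running at seq_len=2048
+    # rescales the base to 10000 * 2 ** (16 / 14) ≈ 22,100 per the formula in forward()
+    # below, slowing the low frequencies so longer contexts stay within one rotation period.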
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + value_residual: bool = False, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = 
v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + fa_dtype = torch.bfloat16 + if _USE_FA3: + y = flash_attn_3_func(q.to(fa_dtype), k.to(fa_dtype), v.to(fa_dtype), causal=True) + else: + # SDPA fallback: (B, T, H, D) -> (B, H, T, D), expand KV for GQA + q_t = q.to(fa_dtype).transpose(1, 2) + k_t = k.to(fa_dtype).transpose(1, 2) + v_t = v.to(fa_dtype).transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(rep, dim=1) + v_t = v_t.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True) + y = y.transpose(1, 2) # (B, H, T, D) -> (B, T, H, D) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)) # (B, T, num_heads) + y = y * gate.unsqueeze(-1) # (B, T, H, 1) broadcast to (B, T, H, D) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.use_leaky = bool(int(os.environ.get("LEAKY_RELU", "1"))) + self.leaky_slope = 
float(os.environ.get("LEAKY_SLOPE", "0.5")) + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), self.leaky_slope) if self.use_leaky else torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class CanonAC(nn.Module): + """Canon Autoregressive Convolution with DeltaGate. Manual shift+mul (no Conv1d).""" + def __init__(self, dim: int, kernel: int = 4, delta_gate_init: float = -4.0): + super().__init__() + self.kernel = kernel + self.weight = nn.Parameter(torch.zeros(kernel, dim)) + self.delta_gate_logit = nn.Parameter(torch.tensor(delta_gate_init)) + + def forward(self, x: Tensor) -> Tensor: + B, T, D = x.shape + K = self.kernel + w = self.weight.to(x.dtype) + x_pad = F.pad(x, (0, 0, K - 1, 0)) + y = w[0] * x_pad[:, K - 1:, :] + for k in range(1, K): + y = y + w[k] * x_pad[:, K - 1 - k : T + K - 1 - k, :] + gate = torch.sigmoid(self.delta_gate_logit.to(x.dtype)) + return x + gate * y + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_kernel: int = 0, + canon_delta_gate_init: float = -4.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, value_residual=value_residual, + gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.canon_a = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + self.canon_c = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_in = self.attn_norm(x) * s + if self.canon_a is not None: + attn_in = self.canon_a(attn_in) + attn_out, raw_v = self.attn(attn_in, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s + if self.canon_c is not None: + mlp_in = self.canon_c(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + return x, raw_v + + +class FlowRefiner(nn.Module): + """1-step flow matching refiner in low-dim latent space. + + Projects hidden states to a low-dim latent, applies a learned velocity + field (1-step Euler), projects back. Output is additive adjustment to + hidden states. Initialized near-zero so the refiner starts as identity. 
+ """ + def __init__(self, model_dim: int, latent_dim: int = 64, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.down_proj = nn.Linear(model_dim, latent_dim, bias=False) + self.velocity_net = nn.Sequential( + nn.Linear(latent_dim, hidden_dim, bias=True), + nn.GELU(), + nn.Linear(hidden_dim, latent_dim, bias=True), + ) + self.up_proj = nn.Linear(latent_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.velocity_net[2]._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.down_proj.weight, std=self._init_scale) + nn.init.normal_(self.velocity_net[0].weight, std=self._init_scale) + nn.init.zeros_(self.velocity_net[0].bias) + nn.init.zeros_(self.velocity_net[2].weight) + nn.init.zeros_(self.velocity_net[2].bias) + nn.init.normal_(self.up_proj.weight, std=self._init_scale) + + def forward(self, x: Tensor) -> Tensor: + z = self.down_proj(x) + v = self.velocity_net(z) + z_refined = z + v + delta = self.up_proj(z_refined) + return torch.sigmoid(self.gate) * delta + + +class NativeFlowMatcher(nn.Module): + """Conditional Flow Matching refiner for hidden states. + + During training: computes auxiliary CFM loss on interpolated states. + During inference: applies a single Euler step correction at t=1. + """ + def __init__(self, model_dim: int, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.model_dim = model_dim + # Sinusoidal time embedding → projected to hidden_dim + self.time_proj = nn.Sequential( + nn.Linear(model_dim, hidden_dim, bias=True), + nn.GELU(), + ) + # Velocity network: x → hidden → x (with time conditioning via addition) + self.v_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.v_act = nn.GELU() + self.v_out = nn.Linear(hidden_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.v_out._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.v_in.weight, std=self._init_scale) + nn.init.zeros_(self.v_in.bias) + nn.init.normal_(self.time_proj[0].weight, std=self._init_scale) + nn.init.zeros_(self.time_proj[0].bias) + nn.init.zeros_(self.v_out.weight) + + @staticmethod + def _sinusoidal_time_emb(t: Tensor, dim: int) -> Tensor: + """Sinusoidal positional embedding for scalar time t. t: (...,) -> (..., dim).""" + half_dim = dim // 2 + emb = math.log(10000.0) / max(half_dim - 1, 1) + emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb) + emb = t.unsqueeze(-1) * emb # (..., half_dim) + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) # (..., dim) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + def _velocity(self, x: Tensor, t_emb: Tensor) -> Tensor: + """Compute velocity v(x, t). x: (..., model_dim), t_emb: (..., hidden_dim).""" + h = self.v_in(x) # (..., hidden_dim) + h = h + t_emb # time conditioning via addition + h = self.v_act(h) # GELU + return self.v_out(h) # (..., model_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (correction, cfm_loss). + + During training: cfm_loss is the flow matching objective (> 0). + During eval: cfm_loss is zero. + correction: gated velocity at t=1, to be added to x. 
+ """ + B_seq = x.shape[:-1] # works for any leading dimensions + D = self.model_dim + + # Inference path: velocity at t=1 (clean input, no noise) + t_ones = torch.ones(*B_seq, device=x.device, dtype=x.dtype) + t_emb_ones = self.time_proj(self._sinusoidal_time_emb(t_ones, D)) + v_at_1 = self._velocity(x, t_emb_ones) + correction = torch.sigmoid(self.gate) * v_at_1 + + # Training path: auxiliary CFM loss + if self.training: + # Sample random t ~ U(0,1) per position + t = torch.rand(*B_seq, device=x.device, dtype=x.dtype) + z = torch.randn_like(x) # noise + # OT interpolant: x_t = (1-t)*z + t*x + t_expanded = t.unsqueeze(-1) # (..., 1) + x_t = (1.0 - t_expanded) * z + t_expanded * x.detach() + # Target velocity = x - z (OT conditional velocity field) + target_v = x.detach() - z + # Predict velocity at interpolated point + t_emb = self.time_proj(self._sinusoidal_time_emb(t, D)) + pred_v = self._velocity(x_t, t_emb) + cfm_loss = F.mse_loss(pred_v, target_v) + else: + cfm_loss = x.new_zeros(()) + + return correction, cfm_loss + + +class TTTLinearRefiner(nn.Module): + """E2E TTT-Linear refiner (Sun et al., 2024, arXiv:2407.04620). + + Applied to final hidden states [B, L, D] before lm_head. The hidden + state is a per-head linear model W updated via mini-batch gradient + descent on a learned self-supervised reconstruction task. Uses the + dual form for GPU efficiency. Trained end-to-end: the outer loop + optimizes theta_K, theta_V, theta_Q projections and the inner-loop + initialization W_0, making the model learn WHAT to compress. + """ + + def __init__(self, model_dim: int, num_heads: int = 8, + mini_batch_size: int = 16, ttt_base_lr: float = 1.0): + super().__init__() + self.model_dim = model_dim + self.num_heads = num_heads + self.head_dim = model_dim // num_heads + self.mini_batch_size = mini_batch_size + self.ttt_base_lr = ttt_base_lr + self.eta = ttt_base_lr / self.head_dim + + # Learnable projections (outer-loop) + self.theta_K = nn.Linear(model_dim, model_dim, bias=False) + self.theta_V = nn.Linear(model_dim, model_dim, bias=False) + self.theta_Q = nn.Linear(model_dim, model_dim, bias=False) + + # Inner-loop model: W_0, b_0 + self.W1 = nn.Parameter(torch.normal(0, 0.02, + size=(num_heads, self.head_dim, self.head_dim))) + self.b1 = nn.Parameter(torch.zeros(num_heads, 1, self.head_dim)) + + # Inner-loop LayerNorm (outer-loop trainable) + self.ttt_ln_w = nn.Parameter(torch.ones(num_heads, self.head_dim)) + self.ttt_ln_b = nn.Parameter(torch.zeros(num_heads, self.head_dim)) + + # Output projection + gate (starts near zero for residual safety) + self.o_proj = nn.Linear(model_dim, model_dim, bias=False) + self.post_norm = nn.LayerNorm(model_dim, eps=1e-6) + self.gate = nn.Parameter(torch.tensor(-5.0)) + nn.init.zeros_(self.o_proj.weight) + self.o_proj._zero_init = True + + @staticmethod + def _ln_fwd(x: Tensor, gamma: Tensor, beta: Tensor, eps: float = 1e-6) -> Tensor: + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x_hat = (x - mu) / torch.sqrt(var + eps) + return gamma * x_hat + beta + + @staticmethod + def _ln_fused_l2_bwd(x: Tensor, target: Tensor, gamma: Tensor, beta: Tensor, + eps: float = 1e-6) -> Tensor: + D = x.shape[-1] + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + std = torch.sqrt(var + eps) + x_hat = (x - mu) / std + y = gamma * x_hat + beta + grad_output = y - target + grad_x_hat = grad_output * gamma + z = (1.0 / D) * ( + D * grad_x_hat + - grad_x_hat.sum(dim=-1, keepdim=True) + - x_hat * 
(grad_x_hat * x_hat).sum(dim=-1, keepdim=True) + ) / std + return z + + def forward(self, x: Tensor) -> Tensor: + """x: [B, L, D]. Returns [B, L, D] additive refinement (gated).""" + B, L, D = x.shape + NH, HD, b = self.num_heads, self.head_dim, self.mini_batch_size + eta = self.eta + + num_mb = L // b + usable_L = num_mb * b + x_proc = x[:, :usable_L] if usable_L < L else x + + # Project: [B, L, D] -> [B, NH, L, HD] + XK = self.theta_K(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XV = self.theta_V(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XQ = self.theta_Q(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + + # Reshape into mini-batches: [B, NH, num_mb, b, HD] + XK = XK.reshape(B, NH, num_mb, b, HD) + XV = XV.reshape(B, NH, num_mb, b, HD) + XQ = XQ.reshape(B, NH, num_mb, b, HD) + + ln_w = self.ttt_ln_w.reshape(1, NH, 1, HD) + ln_b = self.ttt_ln_b.reshape(1, NH, 1, HD) + + # Causal mask for dual form + causal = torch.tril(torch.ones(b, b, device=x.device, dtype=x.dtype)) + + # Initialize inner-loop model + W = self.W1.unsqueeze(0).expand(B, -1, -1, -1).clone() + bias = self.b1.unsqueeze(0).expand(B, -1, -1, -1).clone() + + all_out = [] + for m in range(num_mb): + xk = XK[:, :, m] + xv = XV[:, :, m] + xq = XQ[:, :, m] + + Z1 = xk @ W + bias + target = xv - xk + + grad = self._ln_fused_l2_bwd(Z1, target, ln_w, ln_b) + + Attn = causal.unsqueeze(0).unsqueeze(0) * (xq @ xk.transpose(-2, -1)) + cumgrad = (causal.unsqueeze(0).unsqueeze(0) @ grad) + b_bar = bias - eta * cumgrad + Z_bar = xq @ W - eta * Attn @ grad + b_bar + + W = W - eta * xk.transpose(-2, -1) @ grad + bias = bias - eta * grad.sum(dim=-2, keepdim=True) + + Z_bar = self._ln_fwd(Z_bar, ln_w, ln_b) + out_mb = xq + Z_bar + all_out.append(out_mb) + + z = torch.cat(all_out, dim=2).permute(0, 2, 1, 3).reshape(B, usable_L, D) + z = self.post_norm(z) + z = self.o_proj(z) + result = torch.sigmoid(self.gate) * z + + if usable_L < L: + pad = torch.zeros(B, L - usable_L, D, device=x.device, dtype=x.dtype) + result = torch.cat([result, pad], dim=1) + + return result + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_last_n: int = 0, + canon_kernel: int = 4, + canon_delta_gate_init: float = -4.0, + flow_enabled: bool = False, + flow_latent_dim: int = 64, + flow_hidden_dim: int = 256, + flow_init_scale: float = 0.01, + native_flow_enabled: bool = False, + native_flow_hidden_dim: int = 256, + native_flow_init_scale: float = 0.01, + native_flow_loss_weight: float = 0.1, + e2e_ttt_enabled: bool = False, + e2e_ttt_num_heads: int = 8, + e2e_ttt_mini_batch: int = 16, + e2e_ttt_base_lr: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, 
bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + canon_start = num_layers - canon_last_n if canon_last_n > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + value_residual=value_residual, + gated_attention=gated_attention, + canon_kernel=canon_kernel if i >= canon_start else 0, + canon_delta_gate_init=canon_delta_gate_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.flow_refiner = FlowRefiner(model_dim, flow_latent_dim, flow_hidden_dim, flow_init_scale) if flow_enabled else None + self.native_flow = NativeFlowMatcher(model_dim, native_flow_hidden_dim, native_flow_init_scale) if native_flow_enabled else None + self.native_flow_loss_weight = native_flow_loss_weight + self.ttt_refiner = TTTLinearRefiner(model_dim, e2e_ttt_num_heads, e2e_ttt_mini_batch, e2e_ttt_base_lr) if e2e_ttt_enabled else None + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + nflow_loss = x.new_zeros(()) + if self.native_flow is not None: + correction, nflow_loss = self.native_flow(x) + x = x + correction + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + if self.native_flow is not None and self.training: + main_loss = main_loss + self.native_flow_loss_weight * nflow_loss + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + if self.native_flow is not None: + correction, _ = self.native_flow(x) + x = x + correction + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + 
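+
+# Illustrative sketch (documentation aside, not part of the submitted
+# train_gpt.py): how the evaluators below convert per-token NLL into
+# bits-per-byte. Each scored token contributes its SentencePiece byte length,
+# plus one extra byte when it carries a leading space and the previous token
+# is not already a boundary token (see base_bytes_lut / has_leading_space_lut /
+# is_boundary_token_lut below). `toy_bpb` and the numbers are hypothetical.
+import math
+
+def toy_bpb(nll_sum_nats: float, token_count: int, byte_count: int) -> float:
+    # nats -> bits per token, then rescale by tokens per UTF-8 byte
+    bits_per_token = (nll_sum_nats / token_count) / math.log(2.0)
+    return bits_per_token * (token_count / byte_count)
+
+# e.g. 1,000 scored tokens with total NLL 1896 nats covering 2,443 UTF-8 bytes:
+# toy_bpb(1896.0, 1000, 2443) ~= 1.1197 bits per byte.
+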
+def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + Optionally uses entropy-gated 5-gram cache (NGRAM_CACHE=1).""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # N-gram eval cache with multi-order backoff + entropy-adaptive alpha (PR #702 inspired) + _ngram_default = "1" if world_size > 1 else "0" + use_ngram = bool(int(os.environ.get("NGRAM_CACHE", _ngram_default))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.40")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_order = int(os.environ.get("NGRAM_ORDER", "7")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = ngram_order - ngram_min_order + 1 + ctx_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + print(f"ngram_cache:enabled orders={ngram_min_order}-{ngram_order} backoff " + f"entropy={ngram_entropy} alpha={ngram_alpha} " + f"ent_base={ngram_ent_base} ent_range={ngram_ent_range} " + f"min_count={ngram_min_count} buckets={ngram_buckets}", flush=True) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + scored_nll = nll[i, s:wlen].to(torch.float64) + + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha: compute from model logits (GPU) + if ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ngram_ent_base + ngram_ent_range / ( + 1.0 + np.exp(-ngram_ent_scale * (seg_ent - ngram_ent_thresh))) + + # Precompute hashes for all orders + order_data = [] # (v_idx, ctx_key, full_key) per order + for oi in range(_n_orders): + ctx_w = ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, fill unmatched with lower orders + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if ngram_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll_np = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + scored_nll = torch.from_numpy(seg_nll_np).to(dtype=torch.float64, device=device) + + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + 
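+
+# Illustrative sketch (documentation aside, not part of the submitted
+# train_gpt.py): the probability mixing performed by the score-first n-gram
+# cache in eval_val_sliding above. `toy_entropy_alpha`, `toy_mix`, and all
+# numeric values are hypothetical; parameter names mirror the NGRAM_ENT_* env
+# vars read by the evaluator.
+import numpy as np
+
+def toy_entropy_alpha(entropy: float, base: float = 0.05, ent_range: float = 0.55,
+                      scale: float = 2.0, thresh: float = 4.0) -> float:
+    # Low-entropy (confident) positions keep alpha near `base`; high-entropy
+    # positions approach `base + ent_range`, leaning harder on the n-gram cache.
+    return base + ent_range / (1.0 + np.exp(-scale * (entropy - thresh)))
+
+def toy_mix(p_model: float, entropy: float, ctx_count: float, full_count: float) -> float:
+    # Cached n-gram estimate: count of (context, target) over count of context.
+    p_ngram = min(full_count, ctx_count) / max(ctx_count, 1.0)
+    alpha = toy_entropy_alpha(entropy)
+    p_mixed = (1.0 - alpha) * p_model + alpha * p_ngram
+    return float(-np.log(np.clip(p_mixed, 1e-12, 1.0)))  # scored NLL in nats
+
+# Model gives the target p=0.30 at a high-entropy position (5.0 nats); the
+# highest-order tables have seen the context 8 times and the full n-gram 6 times:
+# toy_mix(0.30, 5.0, 8.0, 6.0) ~= 0.62 nats (alpha ~= 0.53, p_ngram = 0.75).
+# Score-first legality: the context/full count tables are incremented only
+# AFTER this NLL is recorded, so no token is scored with counts that include it.
+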
+ +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _clear_rotary_caches(model: nn.Module) -> None: + """Clear cached RoPE tensors to avoid 'Inference tensors cannot be saved for backward'.""" + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + + +def ttt_adapt(args: Hyperparameters, base_model: nn.Module, device: torch.device, + val_tokens: Tensor, rank: int = 0, world_size: int = 1, + log_fn=None) -> None: + """Score-first TTT: process val data in chunks, score each chunk first + (inference_mode), then train on scored tokens. Compliant with Issue #677.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + batch_seqs = args.ttt_batch_seqs + + # Freeze early blocks + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + chunk_idx = 0 + + for chunk_start in range(0, total_tokens - seq_len, chunk_tokens): + chunk_end = min(chunk_start + chunk_tokens, total_tokens) + chunk_len = chunk_end - chunk_start + n_seqs = chunk_len // seq_len + if n_seqs == 0: + break + + my_start = (n_seqs * rank) // world_size + my_end = (n_seqs * (rank + 1)) // world_size + if my_end <= my_start: + continue + + # Phase 1: Score chunk under inference_mode (forward only) + base_model.eval() + with torch.inference_mode(): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + base_model.forward_logits(x) + + # Phase 2: Train on scored tokens (K epochs) + base_model.train() + for epoch in range(args.ttt_epochs): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + chunk_idx += 1 + if log_fn and chunk_idx % 20 == 0: + log_fn(f"ttt:chunk={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + # Restore all params + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done chunks={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, 
+ eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + """Legal single-pass TTT: score each chunk with sliding windows, then train on it. + Tokens are always scored BEFORE any training on their chunk, so the evaluation + is never contaminated by future information.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Build window starts (same logic as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Map each window to the chunk that contains its first scored token + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + if log_fn: + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + if log_fn: + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + _clear_rotary_caches(base_model) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR schedule with 5% warmup + warmup_chunks = max(num_chunks // 20, 1) + if ci < warmup_chunks: + lr_scale = (ci + 1) / warmup_chunks + else: + progress = (ci - warmup_chunks) / max(num_chunks - 1 - warmup_chunks, 1) + lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress)) + cos_lr = args.ttt_lr * lr_scale + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if log_fn and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all params and return to eval mode + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + if log_fn: + log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, qmax: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + qmin = -qmax - 1 + pct = CastedLinear._quant_percentile + if t32.ndim == 2: + row_max = (torch.quantile(t32.abs(), pct, dim=1) if pct < 1.0 + else t32.abs().amax(dim=1)) + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)).to(torch.float16) + clipped = t32.clamp(-row_max[:, None], row_max[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), qmin, qmax).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / float(qmax) if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), qmin, qmax).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_layers: set[int] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + if int5_layers is None: + int5_layers = set() + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Determine layer index for int5 fallback + layer_idx = -1 + if name.startswith("blocks."): + try: + layer_idx = int(name.split(".")[1]) + except (IndexError, ValueError): + pass + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + qmax = 15 if layer_idx in int5_layers else 31 + q, s = 
quantize_int6_per_row(t, qmax=qmax) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5" if qmax == 15 else "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + if _USE_FA3: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.flow_refiner is not None: + scalar_params.extend(list(base_model.flow_refiner.parameters())) + if base_model.native_flow is not None: + scalar_params.extend(list(base_model.native_flow.parameters())) + if base_model.ttt_refiner is not None: + scalar_params.extend(list(base_model.ttt_refiner.parameters())) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:fa3={_USE_FA3} cudnn=False flash=True mem_efficient={not _USE_FA3} math={not _USE_FA3}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 
1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + eval_only_path = os.environ.get("EVAL_ONLY", "") + if eval_only_path: + log0(f"eval_only: loading {eval_only_path}, skipping training") + base_model.load_state_dict(torch.load(eval_only_path, map_location=device, weights_only=False), strict=False) + ema_state = None # prevent random EMA from overwriting loaded weights + swa_state = None + swa_count = 0 + args.iterations = 0 # skip training, go straight to eval + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms 
step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if ( + args.checkpoint_every > 0 + and step > 0 + and step % args.checkpoint_every == 0 + and not last_step + and master_process + ): + ckpt_sd = {k: v for k, v in base_model.state_dict().items() if "mtp_heads" not in k} + ckpt_path = f"checkpoint_step{step}_{args.run_id}.pt" + torch.save(ckpt_sd, ckpt_path) + log0(f"checkpoint_saved: {ckpt_path} ({os.path.getsize(ckpt_path)} bytes)") + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._soft_round = args.soft_round_qat + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round:{args.soft_round_qat}") + if CastedLinear._qat_enabled and CastedLinear._soft_round: + qat_progress = max(0.0, 1.0 - (scale / qat_threshold)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress # 1→16 + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: + # Tight SWA: collect from EMA state if available, else from raw model + src = ema_state if ema_state is not None else {name: t.detach().float() for name, t in base_model.state_dict().items()} + if swa_state is None: + swa_state = {name: t.clone() for name, t in src.items()} + swa_count = 1 + log0(f"swa:start step:{step} tight={ema_state is not None}") + else: + for name in swa_state: + swa_state[name].add_(src[name]) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the 
wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying Tight SWA averaged {swa_count} EMA checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + if ema_state is not None: + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + # Use RUN_ID for unique filenames when multiple jobs share CWD + _run_id = os.environ.get("RUN_ID", "") + _save_prefix = f"final_model_{_run_id}" if _run_id else "final_model" + _pt_path = f"{_save_prefix}.pt" + _ptz_path = f"{_save_prefix}.int6.ptz" + log0(f"save_paths: pt={_pt_path} ptz={_ptz_path}") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, _pt_path) + model_bytes = os.path.getsize(_pt_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + code_bytes = len(code.encode("utf-8")) + artifact_limit = 16_000_000 - code_bytes + + # --- Auto-downgrade quantization: try int6 first, fall back to int5 middle layers --- + num_layers_total = max( + (int(k.split(".")[1]) for k in sd_cpu if k.startswith("blocks.")), + default=0, + ) + 1 + _zstd_levels = [int(os.environ.get("ZSTD_LEVEL", "16")), 1, 17, 2] + # Phase 1: pure int6 with multiple zstd levels + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = None + chosen_level = _zstd_levels[0] + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int6 zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + # Phase 2: progressive int5 fallback — one layer at a time from middle outward + if quant_blob is None: + mid = num_layers_total // 2 + # Expand outward from center: L5, L4, L6, L3, L7, L2, L8, ... 
+ candidates = [] + for offset in range(num_layers_total): + for sign in [0, 1]: + layer = mid + offset if sign == 0 else mid - offset + if 0 <= layer < num_layers_total and layer not in candidates: + candidates.append(layer) + int5_layers: set[int] = set() + for layer in candidates: + int5_layers.add(layer) + if master_process: + log0(f"quant_fallback: int5 layers={sorted(int5_layers)}") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, int5_layers=int5_layers) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int5[{len(int5_layers)}L] zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + if quant_blob is not None: + break + if quant_blob is None: + quant_blob = blob # Use last attempt even if over limit + if master_process: + log0(f"WARNING: artifact still over limit after all fallbacks") + if master_process: + with open(_ptz_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + log0(f"Serialized model quant+{_COMPRESSOR}-{chosen_level}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open(_ptz_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + legal_ttt = bool(int(os.environ.get("LEGAL_TTT", "0"))) + if args.ttt_enabled and not legal_ttt: + # --- 
Invalid two-pass TTT (adapt then eval separately) --- + if distributed: + dist.barrier() + for block in eval_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + log0(f"ttt:start score-first optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"chunk_tokens={args.ttt_chunk_tokens}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, + rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + if legal_ttt and args.ttt_enabled: + # Legal single-pass TTT: score → train interleaved per chunk + log0(f"legal_ttt:start stride={args.eval_stride} " + f"optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + sw_val_loss, sw_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + log_fn=log0, + ) + else: + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + + +==================================================================================================== +Running Python 3.12.7 | packaged by 
conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] +Running PyTorch 2.10.0+cu128 +Sun Mar 29 20:03:41 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.172.08 Driver Version: 570.172.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA A100-PCIE-40GB On | 00000000:17:00.0 Off | 0 | +| N/A 34C P0 47W / 250W | 423MiB / 40960MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 3356593 C ...ameter_golf/.venv/bin/python3 414MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +eval_only: loading /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_55342820/models/final_model_pr940_nflow_55342820.pt, skipping training +step:0/0 val_loss:1.9197 val_bpb:1.1370 train_time:54ms step_avg:54.40ms +peak memory allocated: 25725 MiB reserved: 25972 MiB +save_paths: pt=final_model_eval_nflow7k_legal_ttt_55375245.pt ptz=final_model_eval_nflow7k_legal_ttt_55375245.int6.ptz +Serialized model: 107295853 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16242245 bytes (limit 15884968) +quant_try int6 zstd-1: 16298729 bytes (limit 15884968) +quant_try int6 zstd-17: 16250339 bytes (limit 15884968) +quant_try int6 zstd-2: 16301974 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15927696 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16001920 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16297960 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16053283 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 
15630744 bytes (limit 15884968) +Serialized model quant+zstd-16: 15630744 bytes +Total submission size: 15745776 bytes +final_int6_roundtrip val_loss:1.9363 val_bpb:1.1468 eval_time:67588ms +final_int6_roundtrip_exact val_loss:1.93630746 val_bpb:1.14679034 +legal_ttt:start stride=64 optimizer=sgd lr=0.002 epochs=10 freeze_blocks=2 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=10 ttt_optimizer=sgd freeze_blocks=2 +ttt_sliding:params unfrozen=22800036 frozen=4730916 + ttt_chunk [1/1893] bpb=1.212107 time=3.8s + ttt_chunk [11/1893] bpb=1.115790 time=41.7s + ttt_chunk [21/1893] bpb=1.125357 time=79.6s + ttt_chunk [31/1893] bpb=1.130512 time=117.6s + ttt_chunk [41/1893] bpb=1.127394 time=155.6s + ttt_chunk [51/1893] bpb=1.128903 time=193.6s + ttt_chunk [61/1893] bpb=1.132827 time=231.7s + ttt_chunk [71/1893] bpb=1.131231 time=269.7s + ttt_chunk [81/1893] bpb=1.127681 time=307.8s + ttt_chunk [91/1893] bpb=1.126507 time=345.8s + ttt_chunk [101/1893] bpb=1.127414 time=383.9s + ttt_chunk [111/1893] bpb=1.127682 time=421.9s + ttt_chunk [121/1893] bpb=1.124073 time=459.9s + ttt_chunk [131/1893] bpb=1.123226 time=498.0s + ttt_chunk [141/1893] bpb=1.122055 time=536.0s + ttt_chunk [151/1893] bpb=1.122195 time=574.0s + ttt_chunk [161/1893] bpb=1.123067 time=612.0s + ttt_chunk [171/1893] bpb=1.125089 time=649.9s + ttt_chunk [181/1893] bpb=1.125155 time=687.9s + ttt_chunk [191/1893] bpb=1.127505 time=725.9s + ttt_chunk [201/1893] bpb=1.126992 time=763.9s + ttt_chunk [211/1893] bpb=1.126053 time=801.8s + ttt_chunk [221/1893] bpb=1.126909 time=839.8s + ttt_chunk [231/1893] bpb=1.126593 time=877.8s + ttt_chunk [241/1893] bpb=1.126801 time=915.7s + ttt_chunk [251/1893] bpb=1.126322 time=953.7s + ttt_chunk [261/1893] bpb=1.125593 time=991.7s + ttt_chunk [271/1893] bpb=1.124720 time=1029.6s + ttt_chunk [281/1893] bpb=1.126223 time=1067.6s + ttt_chunk [291/1893] bpb=1.125781 time=1105.6s + ttt_chunk [301/1893] bpb=1.126625 time=1143.5s + ttt_chunk [311/1893] bpb=1.126647 time=1181.5s + ttt_chunk [321/1893] bpb=1.127307 time=1219.5s + ttt_chunk [331/1893] bpb=1.126728 time=1257.5s + ttt_chunk [341/1893] bpb=1.126294 time=1295.4s + ttt_chunk [351/1893] bpb=1.126974 time=1333.4s + ttt_chunk [361/1893] bpb=1.127742 time=1371.4s + ttt_chunk [371/1893] bpb=1.127643 time=1409.4s + ttt_chunk [381/1893] bpb=1.127394 time=1447.3s + ttt_chunk [391/1893] bpb=1.128075 time=1485.3s + ttt_chunk [401/1893] bpb=1.127618 time=1523.3s + ttt_chunk [411/1893] bpb=1.126636 time=1561.3s + ttt_chunk [421/1893] bpb=1.126693 time=1599.3s + ttt_chunk [431/1893] bpb=1.127118 time=1637.2s + ttt_chunk [441/1893] bpb=1.126511 time=1675.2s + ttt_chunk [451/1893] bpb=1.126643 time=1713.2s + ttt_chunk [461/1893] bpb=1.126522 time=1751.2s + ttt_chunk [471/1893] bpb=1.126096 time=1789.2s + ttt_chunk [481/1893] bpb=1.125905 time=1827.3s + ttt_chunk [491/1893] bpb=1.126069 time=1865.3s + ttt_chunk [501/1893] bpb=1.125822 time=1903.3s + ttt_chunk [511/1893] bpb=1.125318 time=1941.4s + ttt_chunk [521/1893] bpb=1.124852 time=1979.4s + ttt_chunk [531/1893] bpb=1.125548 time=2017.4s + ttt_chunk [541/1893] bpb=1.125645 time=2055.4s + ttt_chunk [551/1893] bpb=1.125098 time=2093.4s + ttt_chunk [561/1893] bpb=1.124953 time=2131.4s + ttt_chunk [571/1893] bpb=1.124686 time=2169.5s + ttt_chunk [581/1893] bpb=1.124306 time=2207.5s + ttt_chunk [591/1893] bpb=1.123730 time=2245.5s + ttt_chunk [601/1893] bpb=1.123722 time=2283.5s + ttt_chunk [611/1893] bpb=1.123387 time=2321.6s + ttt_chunk [621/1893] 
bpb=1.123232 time=2359.6s + ttt_chunk [631/1893] bpb=1.122966 time=2397.6s + ttt_chunk [641/1893] bpb=1.122515 time=2435.7s + ttt_chunk [651/1893] bpb=1.122053 time=2473.7s + ttt_chunk [661/1893] bpb=1.121945 time=2511.7s + ttt_chunk [671/1893] bpb=1.121464 time=2549.7s + ttt_chunk [681/1893] bpb=1.120886 time=2587.7s + ttt_chunk [691/1893] bpb=1.120968 time=2625.7s + ttt_chunk [701/1893] bpb=1.120141 time=2663.7s + ttt_chunk [711/1893] bpb=1.120150 time=2701.6s + ttt_chunk [721/1893] bpb=1.120057 time=2739.6s + ttt_chunk [731/1893] bpb=1.120288 time=2777.6s + ttt_chunk [741/1893] bpb=1.120158 time=2815.6s + ttt_chunk [751/1893] bpb=1.119858 time=2853.6s + ttt_chunk [761/1893] bpb=1.119991 time=2891.6s + ttt_chunk [771/1893] bpb=1.119821 time=2929.6s + ttt_chunk [781/1893] bpb=1.119996 time=2967.5s + ttt_chunk [791/1893] bpb=1.119860 time=3005.5s + ttt_chunk [801/1893] bpb=1.119803 time=3043.5s + ttt_chunk [811/1893] bpb=1.119818 time=3081.5s + ttt_chunk [821/1893] bpb=1.119709 time=3119.5s + ttt_chunk [831/1893] bpb=1.119425 time=3157.5s + ttt_chunk [841/1893] bpb=1.119191 time=3195.5s + ttt_chunk [851/1893] bpb=1.119263 time=3233.5s + ttt_chunk [861/1893] bpb=1.119325 time=3271.6s + ttt_chunk [871/1893] bpb=1.119532 time=3309.6s + ttt_chunk [881/1893] bpb=1.119529 time=3347.6s + ttt_chunk [891/1893] bpb=1.119002 time=3385.6s + ttt_chunk [901/1893] bpb=1.119023 time=3423.6s + ttt_chunk [911/1893] bpb=1.118886 time=3461.6s + ttt_chunk [921/1893] bpb=1.119020 time=3499.6s + ttt_chunk [931/1893] bpb=1.118987 time=3537.6s + ttt_chunk [941/1893] bpb=1.119191 time=3575.6s + ttt_chunk [951/1893] bpb=1.119481 time=3613.6s + ttt_chunk [961/1893] bpb=1.119771 time=3651.5s + ttt_chunk [971/1893] bpb=1.120129 time=3689.5s + ttt_chunk [981/1893] bpb=1.120331 time=3727.5s + ttt_chunk [991/1893] bpb=1.120230 time=3765.5s + ttt_chunk [1001/1893] bpb=1.120542 time=3803.5s + ttt_chunk [1011/1893] bpb=1.120676 time=3841.5s + ttt_chunk [1021/1893] bpb=1.120959 time=3879.5s + ttt_chunk [1031/1893] bpb=1.121332 time=3917.4s + ttt_chunk [1041/1893] bpb=1.121840 time=3955.4s + ttt_chunk [1051/1893] bpb=1.121693 time=3993.4s + ttt_chunk [1061/1893] bpb=1.121785 time=4031.4s + ttt_chunk [1071/1893] bpb=1.121934 time=4069.4s + ttt_chunk [1081/1893] bpb=1.121980 time=4107.4s + ttt_chunk [1091/1893] bpb=1.122239 time=4145.3s + ttt_chunk [1101/1893] bpb=1.122380 time=4183.3s + ttt_chunk [1111/1893] bpb=1.122116 time=4221.3s + ttt_chunk [1121/1893] bpb=1.121879 time=4259.3s + ttt_chunk [1131/1893] bpb=1.121775 time=4297.3s + ttt_chunk [1141/1893] bpb=1.121530 time=4335.2s + ttt_chunk [1151/1893] bpb=1.121538 time=4373.2s + ttt_chunk [1161/1893] bpb=1.121313 time=4411.2s + ttt_chunk [1171/1893] bpb=1.121138 time=4449.2s + ttt_chunk [1181/1893] bpb=1.120908 time=4487.2s + ttt_chunk [1191/1893] bpb=1.121048 time=4525.1s + ttt_chunk [1201/1893] bpb=1.121252 time=4563.1s + ttt_chunk [1211/1893] bpb=1.120843 time=4601.1s + ttt_chunk [1221/1893] bpb=1.121177 time=4639.1s + ttt_chunk [1231/1893] bpb=1.121102 time=4677.1s + ttt_chunk [1241/1893] bpb=1.120801 time=4715.0s + ttt_chunk [1251/1893] bpb=1.120267 time=4753.1s + ttt_chunk [1261/1893] bpb=1.119999 time=4791.1s + ttt_chunk [1271/1893] bpb=1.119752 time=4829.1s + ttt_chunk [1281/1893] bpb=1.119439 time=4867.1s + ttt_chunk [1291/1893] bpb=1.119197 time=4905.1s + ttt_chunk [1301/1893] bpb=1.119156 time=4943.1s + ttt_chunk [1311/1893] bpb=1.118872 time=4981.1s + ttt_chunk [1321/1893] bpb=1.118578 time=5019.2s + ttt_chunk [1331/1893] bpb=1.118339 time=5057.2s + ttt_chunk 
[1341/1893] bpb=1.118212 time=5095.2s + ttt_chunk [1351/1893] bpb=1.118063 time=5133.2s + ttt_chunk [1361/1893] bpb=1.118196 time=5171.2s + ttt_chunk [1371/1893] bpb=1.118409 time=5209.2s + ttt_chunk [1381/1893] bpb=1.118608 time=5247.2s + ttt_chunk [1391/1893] bpb=1.118407 time=5285.2s + ttt_chunk [1401/1893] bpb=1.118446 time=5323.2s + ttt_chunk [1411/1893] bpb=1.118557 time=5361.2s + ttt_chunk [1421/1893] bpb=1.118548 time=5399.2s + ttt_chunk [1431/1893] bpb=1.118528 time=5437.2s + ttt_chunk [1441/1893] bpb=1.119007 time=5475.2s + ttt_chunk [1451/1893] bpb=1.118878 time=5513.1s + ttt_chunk [1461/1893] bpb=1.118811 time=5551.2s + ttt_chunk [1471/1893] bpb=1.119408 time=5589.1s + ttt_chunk [1481/1893] bpb=1.119277 time=5627.1s + ttt_chunk [1491/1893] bpb=1.119642 time=5665.1s + ttt_chunk [1501/1893] bpb=1.119625 time=5703.1s + ttt_chunk [1511/1893] bpb=1.119578 time=5741.1s + ttt_chunk [1521/1893] bpb=1.119692 time=5779.1s + ttt_chunk [1531/1893] bpb=1.119911 time=5817.1s + ttt_chunk [1541/1893] bpb=1.119986 time=5855.0s + ttt_chunk [1551/1893] bpb=1.120227 time=5893.0s + ttt_chunk [1561/1893] bpb=1.120314 time=5931.0s + ttt_chunk [1571/1893] bpb=1.120455 time=5969.0s + ttt_chunk [1581/1893] bpb=1.120609 time=6006.9s + ttt_chunk [1591/1893] bpb=1.120669 time=6044.9s + ttt_chunk [1601/1893] bpb=1.120794 time=6082.9s + ttt_chunk [1611/1893] bpb=1.121048 time=6120.9s + ttt_chunk [1621/1893] bpb=1.120917 time=6158.9s + ttt_chunk [1631/1893] bpb=1.120960 time=6196.8s + ttt_chunk [1641/1893] bpb=1.120983 time=6234.8s + ttt_chunk [1651/1893] bpb=1.121033 time=6272.8s + ttt_chunk [1661/1893] bpb=1.121175 time=6310.8s + ttt_chunk [1671/1893] bpb=1.121358 time=6348.8s + ttt_chunk [1681/1893] bpb=1.121447 time=6386.8s + ttt_chunk [1691/1893] bpb=1.121552 time=6424.8s + ttt_chunk [1701/1893] bpb=1.121657 time=6462.8s + ttt_chunk [1711/1893] bpb=1.121636 time=6500.8s + ttt_chunk [1721/1893] bpb=1.121475 time=6538.7s + ttt_chunk [1731/1893] bpb=1.121572 time=6576.7s + ttt_chunk [1741/1893] bpb=1.121311 time=6614.7s + ttt_chunk [1751/1893] bpb=1.121184 time=6652.7s + ttt_chunk [1761/1893] bpb=1.121225 time=6690.7s + ttt_chunk [1771/1893] bpb=1.121172 time=6728.6s + ttt_chunk [1781/1893] bpb=1.121072 time=6766.6s + ttt_chunk [1791/1893] bpb=1.120730 time=6804.6s + ttt_chunk [1801/1893] bpb=1.120711 time=6842.6s + ttt_chunk [1811/1893] bpb=1.120557 time=6880.5s + ttt_chunk [1821/1893] bpb=1.120615 time=6918.5s + ttt_chunk [1831/1893] bpb=1.120464 time=6956.5s + ttt_chunk [1841/1893] bpb=1.120475 time=6994.4s + ttt_chunk [1851/1893] bpb=1.120303 time=7032.4s + ttt_chunk [1861/1893] bpb=1.120214 time=7070.4s + ttt_chunk [1871/1893] bpb=1.120145 time=7108.4s + ttt_chunk [1881/1893] bpb=1.119902 time=7146.3s + ttt_chunk [1891/1893] bpb=1.119878 time=7184.3s + ttt_chunk [1893/1893] bpb=1.119907 time=7189.7s +ttt_sliding:done val_loss=1.890910 val_bpb=1.119907 elapsed=7189.8s +final_int6_sliding_window val_loss:1.8909 val_bpb:1.1199 stride:64 eval_time:7190260ms +final_int6_sliding_window_exact val_loss:1.89091021 val_bpb:1.11990650 diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_nflow7k_nottt.log b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_nflow7k_nottt.log new file mode 100644 index 0000000000..905cbcfb4e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/eval_nflow7k_nottt.log @@ -0,0 +1,2661 @@ +logs/eval_nflow7k_nottt_55375246.txt +val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +eval_only: loading /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_55342820/models/final_model_pr940_nflow_55342820.pt, skipping training +step:0/0 val_loss:1.9197 val_bpb:1.1370 train_time:67ms step_avg:67.20ms +peak memory allocated: 25725 MiB reserved: 25972 MiB +save_paths: pt=final_model_eval_nflow7k_nottt_55375246.pt ptz=final_model_eval_nflow7k_nottt_55375246.int6.ptz +Serialized model: 107295209 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16242245 bytes (limit 15884968) +quant_try int6 zstd-1: 16298729 bytes (limit 15884968) +quant_try int6 zstd-17: 16250339 bytes (limit 15884968) +quant_try int6 zstd-2: 16301974 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15927696 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16001920 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16297960 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16053283 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15630744 bytes (limit 15884968) +Serialized model quant+zstd-16: 15630744 bytes +Total submission size: 15745776 bytes +final_int6_roundtrip val_loss:1.9363 val_bpb:1.1468 eval_time:67735ms +final_int6_roundtrip_exact val_loss:1.93630739 val_bpb:1.14679030 +final_int6_sliding_window val_loss:1.8963 val_bpb:1.1231 stride:64 eval_time:1702654ms +final_int6_sliding_window_exact val_loss:1.89632895 val_bpb:1.12311579 + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
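    # (Presumably, following the modded-nanogpt recipe this script builds on, the
    # embedding/scalar groups are updated with Adam-style optimizers at embed_lr /
    # scalar_lr while the 2D matrices use the Muon optimizer defined further down
    # at matrix_lr; the exact grouping happens where the optimizers are built.)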
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0"))) + swa_every = int(os.environ.get("SWA_EVERY", 200)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) + soft_round_qat = bool(int(os.environ.get("SOFT_ROUND_QAT", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "1"))) + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "1"))) + canon_last_n = int(os.environ.get("CANON_LAST_N", 0)) + canon_kernel = int(os.environ.get("CANON_KERNEL", 4)) + canon_delta_gate_init = float(os.environ.get("CANON_DELTA_GATE_INIT", -4.0)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # TTT (Test-Time Training) — score-first, backward-looking + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # "sgd" or "adamw" + ttt_lr = float(os.environ.get("TTT_LR", 0.0001)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 4)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 131072)) + + # FlowRefiner — additive latent-space refinement + flow_enabled = bool(int(os.environ.get("FLOW_ENABLED", "0"))) + flow_latent_dim = int(os.environ.get("FLOW_LATENT_DIM", "64")) + flow_hidden_dim = int(os.environ.get("FLOW_HIDDEN_DIM", "256")) + flow_init_scale = float(os.environ.get("FLOW_INIT_SCALE", "0.01")) + + # NativeFlowMatcher — time-conditioned CFM on hidden states + native_flow_enabled = bool(int(os.environ.get("NATIVE_FLOW_ENABLED", "0"))) + native_flow_hidden_dim = int(os.environ.get("NATIVE_FLOW_HIDDEN_DIM", "256")) + native_flow_init_scale = float(os.environ.get("NATIVE_FLOW_INIT_SCALE", "0.01")) + native_flow_loss_weight = float(os.environ.get("NATIVE_FLOW_LOSS_WEIGHT", "0.1")) + + # E2E TTT-Linear refiner (Sun et al., 
2024) + e2e_ttt_enabled = bool(int(os.environ.get("E2E_TTT_ENABLED", "0"))) + e2e_ttt_num_heads = int(os.environ.get("E2E_TTT_NUM_HEADS", "8")) + e2e_ttt_mini_batch = int(os.environ.get("E2E_TTT_MINI_BATCH", "16")) + e2e_ttt_base_lr = float(os.environ.get("E2E_TTT_BASE_LR", "1.0")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
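# Illustrative check of that accounting, using numbers from the seed=42 legal-TTT
# log above: eval_val() below computes
#
#     bpb = (val_loss / ln 2) * (total_tokens / total_bytes)
#
# so val_loss = 1.89091021 together with val_bpb = 1.11990650 implies the
# validation split averages roughly (1.89091 / 0.693147) / 1.11991 ≈ 2.44 UTF-8
# bytes per token for this 1024-entry SentencePiece vocab. The helpers below
# build the per-token byte counts, including the leading-space byte that is only
# charged when the previous token is not a boundary (control) token.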
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + 
token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the weights to int8 and zstd-compressing them. +# We can then decompress the model and run it in higher precision for evaluation, once the artifact fits under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,vr_lambda,attn_gate,canon_a,canon_c,delta_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, qmax: int = 127) -> tuple[Tensor, Tensor]: + """Quantize to [-qmax, qmax] range. 
Default int8 (qmax=127), int6 (qmax=31), int5 (qmax=15).""" + t32 = t.float() + qmin = -qmax + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
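        # (Illustrative: for this 11L / 512-dim model the per-block control
        # vectors — attn_scale / mlp_scale with 512 entries, q_gain with 8 —
        # match CONTROL_TENSOR_NAME_PATTERNS and are kept in fp32, other small
        # floats are stored as fp16, and all of them sit far below
        # INT8_KEEP_FLOAT_MAX_NUMEL = 65_536, so only the large 2D matmul
        # weights reach quantize_float_tensor. A hypothetical round-trip of that
        # path, assuming a random weight of MLP shape:
        #     q, s = quantize_float_tensor(torch.randn(512, 1536), qmax=31)  # "int6" range
        #     w_hat = q.float() * s.float()[:, None]                         # per-row dequantize
        # which mirrors what dequantize_state_dict_int8 does at load time.)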
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 for MLP layers 3-7 to save artifact space + int6_mlp_layers = os.environ.get("INT6_MLP_LAYERS", "") + qmax = 127 # default int8 + if int6_mlp_layers: + for li in int6_mlp_layers.split(","): + if li.strip() and f"blocks.{li.strip()}.mlp" in name and t.ndim == 2: + qmax = 31 # int6 + break + q, s = quantize_float_tensor(t, qmax=qmax) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
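    # (Worked example with the logged config — train_batch_tokens = 786,432,
    # world_size = 1, grad_accum_steps = 8, seq_len = 2048: each next_batch()
    # call takes 786,432 / (1 * 8) = 98,304 tokens per rank, i.e. 48 sequences
    # of 2,048, and the shared stream advances by 98,305 tokens so the shifted
    # targets y stay aligned with the inputs x.)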
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round: bool = False + _soft_round_alpha: float = 1.0 + _quant_percentile: float = float(os.environ.get("QUANT_PERCENTILE", "1.0")) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + w32 = self.weight.float() + pct = CastedLinear._quant_percentile + row_max = (torch.quantile(w32.abs(), pct, dim=1) if pct < 1.0 + else w32.abs().amax(dim=1)).detach() + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + r = w32 / scale[:, None] + if CastedLinear._soft_round: + alpha = CastedLinear._soft_round_alpha + r_frac = r - r.detach().floor() - 0.5 + norm = torch.tanh(torch.tensor(alpha * 0.5, device=r.device, dtype=r.dtype)) + r_soft = r.detach().floor() + 0.5 + torch.tanh(alpha * r_frac) / (2.0 * norm) + w_q = (torch.clamp(r_soft, -32, 31) * scale[:, None]).to(x.dtype) + w = w_q # soft-round is differentiable, no STE needed + else: + with torch.no_grad(): + w_q = (torch.clamp(torch.round(r), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. 
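    # (Illustrative: with the config above — rope_dims = 16, base = 10000,
    # train_seq_len = 1024 — a 2048-token training sequence rescales the base by
    # (2048 / 1024) ** (16 / 14) ≈ 2.21, i.e. to roughly 22,000, which stretches
    # the RoPE wavelengths rather than extrapolating past the trained range.)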
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + value_residual: bool = False, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = 
v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + fa_dtype = torch.bfloat16 + if _USE_FA3: + y = flash_attn_3_func(q.to(fa_dtype), k.to(fa_dtype), v.to(fa_dtype), causal=True) + else: + # SDPA fallback: (B, T, H, D) -> (B, H, T, D), expand KV for GQA + q_t = q.to(fa_dtype).transpose(1, 2) + k_t = k.to(fa_dtype).transpose(1, 2) + v_t = v.to(fa_dtype).transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(rep, dim=1) + v_t = v_t.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True) + y = y.transpose(1, 2) # (B, H, T, D) -> (B, T, H, D) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)) # (B, T, num_heads) + y = y * gate.unsqueeze(-1) # (B, T, H, 1) broadcast to (B, T, H, D) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.use_leaky = bool(int(os.environ.get("LEAKY_RELU", "1"))) + self.leaky_slope = 
float(os.environ.get("LEAKY_SLOPE", "0.5")) + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), self.leaky_slope) if self.use_leaky else torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class CanonAC(nn.Module): + """Canon Autoregressive Convolution with DeltaGate. Manual shift+mul (no Conv1d).""" + def __init__(self, dim: int, kernel: int = 4, delta_gate_init: float = -4.0): + super().__init__() + self.kernel = kernel + self.weight = nn.Parameter(torch.zeros(kernel, dim)) + self.delta_gate_logit = nn.Parameter(torch.tensor(delta_gate_init)) + + def forward(self, x: Tensor) -> Tensor: + B, T, D = x.shape + K = self.kernel + w = self.weight.to(x.dtype) + x_pad = F.pad(x, (0, 0, K - 1, 0)) + y = w[0] * x_pad[:, K - 1:, :] + for k in range(1, K): + y = y + w[k] * x_pad[:, K - 1 - k : T + K - 1 - k, :] + gate = torch.sigmoid(self.delta_gate_logit.to(x.dtype)) + return x + gate * y + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_kernel: int = 0, + canon_delta_gate_init: float = -4.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, value_residual=value_residual, + gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.canon_a = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + self.canon_c = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_in = self.attn_norm(x) * s + if self.canon_a is not None: + attn_in = self.canon_a(attn_in) + attn_out, raw_v = self.attn(attn_in, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s + if self.canon_c is not None: + mlp_in = self.canon_c(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + return x, raw_v + + +class FlowRefiner(nn.Module): + """1-step flow matching refiner in low-dim latent space. + + Projects hidden states to a low-dim latent, applies a learned velocity + field (1-step Euler), projects back. Output is additive adjustment to + hidden states. Initialized near-zero so the refiner starts as identity. 
+ """ + def __init__(self, model_dim: int, latent_dim: int = 64, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.down_proj = nn.Linear(model_dim, latent_dim, bias=False) + self.velocity_net = nn.Sequential( + nn.Linear(latent_dim, hidden_dim, bias=True), + nn.GELU(), + nn.Linear(hidden_dim, latent_dim, bias=True), + ) + self.up_proj = nn.Linear(latent_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.velocity_net[2]._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.down_proj.weight, std=self._init_scale) + nn.init.normal_(self.velocity_net[0].weight, std=self._init_scale) + nn.init.zeros_(self.velocity_net[0].bias) + nn.init.zeros_(self.velocity_net[2].weight) + nn.init.zeros_(self.velocity_net[2].bias) + nn.init.normal_(self.up_proj.weight, std=self._init_scale) + + def forward(self, x: Tensor) -> Tensor: + z = self.down_proj(x) + v = self.velocity_net(z) + z_refined = z + v + delta = self.up_proj(z_refined) + return torch.sigmoid(self.gate) * delta + + +class NativeFlowMatcher(nn.Module): + """Conditional Flow Matching refiner for hidden states. + + During training: computes auxiliary CFM loss on interpolated states. + During inference: applies a single Euler step correction at t=1. + """ + def __init__(self, model_dim: int, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.model_dim = model_dim + # Sinusoidal time embedding → projected to hidden_dim + self.time_proj = nn.Sequential( + nn.Linear(model_dim, hidden_dim, bias=True), + nn.GELU(), + ) + # Velocity network: x → hidden → x (with time conditioning via addition) + self.v_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.v_act = nn.GELU() + self.v_out = nn.Linear(hidden_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.v_out._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.v_in.weight, std=self._init_scale) + nn.init.zeros_(self.v_in.bias) + nn.init.normal_(self.time_proj[0].weight, std=self._init_scale) + nn.init.zeros_(self.time_proj[0].bias) + nn.init.zeros_(self.v_out.weight) + + @staticmethod + def _sinusoidal_time_emb(t: Tensor, dim: int) -> Tensor: + """Sinusoidal positional embedding for scalar time t. t: (...,) -> (..., dim).""" + half_dim = dim // 2 + emb = math.log(10000.0) / max(half_dim - 1, 1) + emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb) + emb = t.unsqueeze(-1) * emb # (..., half_dim) + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) # (..., dim) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + def _velocity(self, x: Tensor, t_emb: Tensor) -> Tensor: + """Compute velocity v(x, t). x: (..., model_dim), t_emb: (..., hidden_dim).""" + h = self.v_in(x) # (..., hidden_dim) + h = h + t_emb # time conditioning via addition + h = self.v_act(h) # GELU + return self.v_out(h) # (..., model_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (correction, cfm_loss). + + During training: cfm_loss is the flow matching objective (> 0). + During eval: cfm_loss is zero. + correction: gated velocity at t=1, to be added to x. 
+ """ + B_seq = x.shape[:-1] # works for any leading dimensions + D = self.model_dim + + # Inference path: velocity at t=1 (clean input, no noise) + t_ones = torch.ones(*B_seq, device=x.device, dtype=x.dtype) + t_emb_ones = self.time_proj(self._sinusoidal_time_emb(t_ones, D)) + v_at_1 = self._velocity(x, t_emb_ones) + correction = torch.sigmoid(self.gate) * v_at_1 + + # Training path: auxiliary CFM loss + if self.training: + # Sample random t ~ U(0,1) per position + t = torch.rand(*B_seq, device=x.device, dtype=x.dtype) + z = torch.randn_like(x) # noise + # OT interpolant: x_t = (1-t)*z + t*x + t_expanded = t.unsqueeze(-1) # (..., 1) + x_t = (1.0 - t_expanded) * z + t_expanded * x.detach() + # Target velocity = x - z (OT conditional velocity field) + target_v = x.detach() - z + # Predict velocity at interpolated point + t_emb = self.time_proj(self._sinusoidal_time_emb(t, D)) + pred_v = self._velocity(x_t, t_emb) + cfm_loss = F.mse_loss(pred_v, target_v) + else: + cfm_loss = x.new_zeros(()) + + return correction, cfm_loss + + +class TTTLinearRefiner(nn.Module): + """E2E TTT-Linear refiner (Sun et al., 2024, arXiv:2407.04620). + + Applied to final hidden states [B, L, D] before lm_head. The hidden + state is a per-head linear model W updated via mini-batch gradient + descent on a learned self-supervised reconstruction task. Uses the + dual form for GPU efficiency. Trained end-to-end: the outer loop + optimizes theta_K, theta_V, theta_Q projections and the inner-loop + initialization W_0, making the model learn WHAT to compress. + """ + + def __init__(self, model_dim: int, num_heads: int = 8, + mini_batch_size: int = 16, ttt_base_lr: float = 1.0): + super().__init__() + self.model_dim = model_dim + self.num_heads = num_heads + self.head_dim = model_dim // num_heads + self.mini_batch_size = mini_batch_size + self.ttt_base_lr = ttt_base_lr + self.eta = ttt_base_lr / self.head_dim + + # Learnable projections (outer-loop) + self.theta_K = nn.Linear(model_dim, model_dim, bias=False) + self.theta_V = nn.Linear(model_dim, model_dim, bias=False) + self.theta_Q = nn.Linear(model_dim, model_dim, bias=False) + + # Inner-loop model: W_0, b_0 + self.W1 = nn.Parameter(torch.normal(0, 0.02, + size=(num_heads, self.head_dim, self.head_dim))) + self.b1 = nn.Parameter(torch.zeros(num_heads, 1, self.head_dim)) + + # Inner-loop LayerNorm (outer-loop trainable) + self.ttt_ln_w = nn.Parameter(torch.ones(num_heads, self.head_dim)) + self.ttt_ln_b = nn.Parameter(torch.zeros(num_heads, self.head_dim)) + + # Output projection + gate (starts near zero for residual safety) + self.o_proj = nn.Linear(model_dim, model_dim, bias=False) + self.post_norm = nn.LayerNorm(model_dim, eps=1e-6) + self.gate = nn.Parameter(torch.tensor(-5.0)) + nn.init.zeros_(self.o_proj.weight) + self.o_proj._zero_init = True + + @staticmethod + def _ln_fwd(x: Tensor, gamma: Tensor, beta: Tensor, eps: float = 1e-6) -> Tensor: + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x_hat = (x - mu) / torch.sqrt(var + eps) + return gamma * x_hat + beta + + @staticmethod + def _ln_fused_l2_bwd(x: Tensor, target: Tensor, gamma: Tensor, beta: Tensor, + eps: float = 1e-6) -> Tensor: + D = x.shape[-1] + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + std = torch.sqrt(var + eps) + x_hat = (x - mu) / std + y = gamma * x_hat + beta + grad_output = y - target + grad_x_hat = grad_output * gamma + z = (1.0 / D) * ( + D * grad_x_hat + - grad_x_hat.sum(dim=-1, keepdim=True) + - x_hat * 
(grad_x_hat * x_hat).sum(dim=-1, keepdim=True) + ) / std + return z + + def forward(self, x: Tensor) -> Tensor: + """x: [B, L, D]. Returns [B, L, D] additive refinement (gated).""" + B, L, D = x.shape + NH, HD, b = self.num_heads, self.head_dim, self.mini_batch_size + eta = self.eta + + num_mb = L // b + usable_L = num_mb * b + x_proc = x[:, :usable_L] if usable_L < L else x + + # Project: [B, L, D] -> [B, NH, L, HD] + XK = self.theta_K(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XV = self.theta_V(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XQ = self.theta_Q(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + + # Reshape into mini-batches: [B, NH, num_mb, b, HD] + XK = XK.reshape(B, NH, num_mb, b, HD) + XV = XV.reshape(B, NH, num_mb, b, HD) + XQ = XQ.reshape(B, NH, num_mb, b, HD) + + ln_w = self.ttt_ln_w.reshape(1, NH, 1, HD) + ln_b = self.ttt_ln_b.reshape(1, NH, 1, HD) + + # Causal mask for dual form + causal = torch.tril(torch.ones(b, b, device=x.device, dtype=x.dtype)) + + # Initialize inner-loop model + W = self.W1.unsqueeze(0).expand(B, -1, -1, -1).clone() + bias = self.b1.unsqueeze(0).expand(B, -1, -1, -1).clone() + + all_out = [] + for m in range(num_mb): + xk = XK[:, :, m] + xv = XV[:, :, m] + xq = XQ[:, :, m] + + Z1 = xk @ W + bias + target = xv - xk + + grad = self._ln_fused_l2_bwd(Z1, target, ln_w, ln_b) + + Attn = causal.unsqueeze(0).unsqueeze(0) * (xq @ xk.transpose(-2, -1)) + cumgrad = (causal.unsqueeze(0).unsqueeze(0) @ grad) + b_bar = bias - eta * cumgrad + Z_bar = xq @ W - eta * Attn @ grad + b_bar + + W = W - eta * xk.transpose(-2, -1) @ grad + bias = bias - eta * grad.sum(dim=-2, keepdim=True) + + Z_bar = self._ln_fwd(Z_bar, ln_w, ln_b) + out_mb = xq + Z_bar + all_out.append(out_mb) + + z = torch.cat(all_out, dim=2).permute(0, 2, 1, 3).reshape(B, usable_L, D) + z = self.post_norm(z) + z = self.o_proj(z) + result = torch.sigmoid(self.gate) * z + + if usable_L < L: + pad = torch.zeros(B, L - usable_L, D, device=x.device, dtype=x.dtype) + result = torch.cat([result, pad], dim=1) + + return result + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_last_n: int = 0, + canon_kernel: int = 4, + canon_delta_gate_init: float = -4.0, + flow_enabled: bool = False, + flow_latent_dim: int = 64, + flow_hidden_dim: int = 256, + flow_init_scale: float = 0.01, + native_flow_enabled: bool = False, + native_flow_hidden_dim: int = 256, + native_flow_init_scale: float = 0.01, + native_flow_loss_weight: float = 0.1, + e2e_ttt_enabled: bool = False, + e2e_ttt_num_heads: int = 8, + e2e_ttt_mini_batch: int = 16, + e2e_ttt_base_lr: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, 
bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + canon_start = num_layers - canon_last_n if canon_last_n > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + value_residual=value_residual, + gated_attention=gated_attention, + canon_kernel=canon_kernel if i >= canon_start else 0, + canon_delta_gate_init=canon_delta_gate_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.flow_refiner = FlowRefiner(model_dim, flow_latent_dim, flow_hidden_dim, flow_init_scale) if flow_enabled else None + self.native_flow = NativeFlowMatcher(model_dim, native_flow_hidden_dim, native_flow_init_scale) if native_flow_enabled else None + self.native_flow_loss_weight = native_flow_loss_weight + self.ttt_refiner = TTTLinearRefiner(model_dim, e2e_ttt_num_heads, e2e_ttt_mini_batch, e2e_ttt_base_lr) if e2e_ttt_enabled else None + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + nflow_loss = x.new_zeros(()) + if self.native_flow is not None: + correction, nflow_loss = self.native_flow(x) + x = x + correction + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + if self.native_flow is not None and self.training: + main_loss = main_loss + self.native_flow_loss_weight * nflow_loss + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + if self.native_flow is not None: + correction, _ = self.native_flow(x) + x = x + correction + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + 
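+# Scoring scheme: window starts advance by `stride`; the first window scores all of
+# its tokens and every later window scores only its final `stride` tokens, so each
+# token is predicted with (close to) `seq_len` tokens of left context. BPB is then
+# (mean NLL / ln 2) * (scored tokens / scored bytes), with byte counts taken from
+# the SentencePiece LUTs.
+#
+# With NGRAM_CACHE=1 the model probability is mixed with a hashed n-gram estimate
+# built score-first over the validation stream (counts are updated only after the
+# corresponding tokens have been scored). The cache uses multi-order backoff over
+# orders NGRAM_MIN_ORDER..NGRAM_ORDER (default 2..7), matching the highest order
+# whose hashed context count reaches NGRAM_MIN_COUNT:
+#
+#     p_mix = (1 - alpha) * p_model + alpha * p_ngram
+#     alpha = NGRAM_ALPHA (fixed), or with NGRAM_ENTROPY=1
+#     alpha = NGRAM_ENT_BASE + NGRAM_ENT_RANGE * sigmoid(NGRAM_ENT_SCALE * (H - NGRAM_ENT_THRESH))
+#
+# where H is the model's per-token predictive entropy in nats.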
+def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + Optionally uses entropy-gated 5-gram cache (NGRAM_CACHE=1).""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # N-gram eval cache with multi-order backoff + entropy-adaptive alpha (PR #702 inspired) + _ngram_default = "1" if world_size > 1 else "0" + use_ngram = bool(int(os.environ.get("NGRAM_CACHE", _ngram_default))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.40")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_order = int(os.environ.get("NGRAM_ORDER", "7")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = ngram_order - ngram_min_order + 1 + ctx_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + print(f"ngram_cache:enabled orders={ngram_min_order}-{ngram_order} backoff " + f"entropy={ngram_entropy} alpha={ngram_alpha} " + f"ent_base={ngram_ent_base} ent_range={ngram_ent_range} " + f"min_count={ngram_min_count} buckets={ngram_buckets}", flush=True) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + scored_nll = nll[i, s:wlen].to(torch.float64) + + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha: compute from model logits (GPU) + if ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ngram_ent_base + ngram_ent_range / ( + 1.0 + np.exp(-ngram_ent_scale * (seg_ent - ngram_ent_thresh))) + + # Precompute hashes for all orders + order_data = [] # (v_idx, ctx_key, full_key) per order + for oi in range(_n_orders): + ctx_w = ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, fill unmatched with lower orders + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if ngram_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll_np = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + scored_nll = torch.from_numpy(seg_nll_np).to(dtype=torch.float64, device=device) + + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + 
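+# Worked example of the entropy-adaptive mixing above (defaults NGRAM_ENT_BASE=0.05,
+# NGRAM_ENT_RANGE=0.55, NGRAM_ENT_SCALE=2.0, NGRAM_ENT_THRESH=4.0): a confident token
+# with entropy H = 2.0 nats gets alpha = 0.05 + 0.55 * sigmoid(2.0 * (2.0 - 4.0)) ≈ 0.06,
+# while an uncertain token with H = 4.0 nats gets alpha = 0.05 + 0.55 * 0.5 = 0.325, so
+# the n-gram estimate is trusted mainly where the model itself is unsure.
+#
+# Unit check for the returned pair: the seed=42 run logged below reports
+# val_loss = 1.89632895 nats/token, i.e. 1.89632895 / ln 2 ≈ 2.736 bits/token, and
+# val_bpb = 1.12311579, which implies roughly 0.41 scored tokens per byte on this
+# validation split.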
+ +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _clear_rotary_caches(model: nn.Module) -> None: + """Clear cached RoPE tensors to avoid 'Inference tensors cannot be saved for backward'.""" + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + + +def ttt_adapt(args: Hyperparameters, base_model: nn.Module, device: torch.device, + val_tokens: Tensor, rank: int = 0, world_size: int = 1, + log_fn=None) -> None: + """Score-first TTT: process val data in chunks, score each chunk first + (inference_mode), then train on scored tokens. Compliant with Issue #677.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + batch_seqs = args.ttt_batch_seqs + + # Freeze early blocks + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + chunk_idx = 0 + + for chunk_start in range(0, total_tokens - seq_len, chunk_tokens): + chunk_end = min(chunk_start + chunk_tokens, total_tokens) + chunk_len = chunk_end - chunk_start + n_seqs = chunk_len // seq_len + if n_seqs == 0: + break + + my_start = (n_seqs * rank) // world_size + my_end = (n_seqs * (rank + 1)) // world_size + if my_end <= my_start: + continue + + # Phase 1: Score chunk under inference_mode (forward only) + base_model.eval() + with torch.inference_mode(): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + base_model.forward_logits(x) + + # Phase 2: Train on scored tokens (K epochs) + base_model.train() + for epoch in range(args.ttt_epochs): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + chunk_idx += 1 + if log_fn and chunk_idx % 20 == 0: + log_fn(f"ttt:chunk={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + # Restore all params + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done chunks={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, 
+ eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + """Legal single-pass TTT: score each chunk with sliding windows, then train on it. + Tokens are always scored BEFORE any training on their chunk, so the evaluation + is never contaminated by future information.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Build window starts (same logic as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Map each window to the chunk that contains its first scored token + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + if log_fn: + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + if log_fn: + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + _clear_rotary_caches(base_model) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR schedule with 5% warmup + warmup_chunks = max(num_chunks // 20, 1) + if ci < warmup_chunks: + lr_scale = (ci + 1) / warmup_chunks + else: + progress = (ci - warmup_chunks) / max(num_chunks - 1 - warmup_chunks, 1) + lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress)) + cos_lr = args.ttt_lr * lr_scale + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if log_fn and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all params and return to eval mode + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + if log_fn: + log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, qmax: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + qmin = -qmax - 1 + pct = CastedLinear._quant_percentile + if t32.ndim == 2: + row_max = (torch.quantile(t32.abs(), pct, dim=1) if pct < 1.0 + else t32.abs().amax(dim=1)) + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)).to(torch.float16) + clipped = t32.clamp(-row_max[:, None], row_max[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), qmin, qmax).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / float(qmax) if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), qmin, qmax).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_layers: set[int] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + if int5_layers is None: + int5_layers = set() + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Determine layer index for int5 fallback + layer_idx = -1 + if name.startswith("blocks."): + try: + layer_idx = int(name.split(".")[1]) + except (IndexError, ValueError): + pass + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + qmax = 15 if layer_idx in int5_layers else 31 + q, s = 
quantize_int6_per_row(t, qmax=qmax) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5" if qmax == 15 else "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + if _USE_FA3: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.flow_refiner is not None: + scalar_params.extend(list(base_model.flow_refiner.parameters())) + if base_model.native_flow is not None: + scalar_params.extend(list(base_model.native_flow.parameters())) + if base_model.ttt_refiner is not None: + scalar_params.extend(list(base_model.ttt_refiner.parameters())) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:fa3={_USE_FA3} cudnn=False flash=True mem_efficient={not _USE_FA3} math={not _USE_FA3}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 
1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + eval_only_path = os.environ.get("EVAL_ONLY", "") + if eval_only_path: + log0(f"eval_only: loading {eval_only_path}, skipping training") + base_model.load_state_dict(torch.load(eval_only_path, map_location=device, weights_only=False), strict=False) + ema_state = None # prevent random EMA from overwriting loaded weights + swa_state = None + swa_count = 0 + args.iterations = 0 # skip training, go straight to eval + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms 
step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if ( + args.checkpoint_every > 0 + and step > 0 + and step % args.checkpoint_every == 0 + and not last_step + and master_process + ): + ckpt_sd = {k: v for k, v in base_model.state_dict().items() if "mtp_heads" not in k} + ckpt_path = f"checkpoint_step{step}_{args.run_id}.pt" + torch.save(ckpt_sd, ckpt_path) + log0(f"checkpoint_saved: {ckpt_path} ({os.path.getsize(ckpt_path)} bytes)") + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._soft_round = args.soft_round_qat + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round:{args.soft_round_qat}") + if CastedLinear._qat_enabled and CastedLinear._soft_round: + qat_progress = max(0.0, 1.0 - (scale / qat_threshold)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress # 1→16 + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: + # Tight SWA: collect from EMA state if available, else from raw model + src = ema_state if ema_state is not None else {name: t.detach().float() for name, t in base_model.state_dict().items()} + if swa_state is None: + swa_state = {name: t.clone() for name, t in src.items()} + swa_count = 1 + log0(f"swa:start step:{step} tight={ema_state is not None}") + else: + for name in swa_state: + swa_state[name].add_(src[name]) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the 
wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying Tight SWA averaged {swa_count} EMA checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + if ema_state is not None: + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + # Use RUN_ID for unique filenames when multiple jobs share CWD + _run_id = os.environ.get("RUN_ID", "") + _save_prefix = f"final_model_{_run_id}" if _run_id else "final_model" + _pt_path = f"{_save_prefix}.pt" + _ptz_path = f"{_save_prefix}.int6.ptz" + log0(f"save_paths: pt={_pt_path} ptz={_ptz_path}") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, _pt_path) + model_bytes = os.path.getsize(_pt_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + code_bytes = len(code.encode("utf-8")) + artifact_limit = 16_000_000 - code_bytes + + # --- Auto-downgrade quantization: try int6 first, fall back to int5 middle layers --- + num_layers_total = max( + (int(k.split(".")[1]) for k in sd_cpu if k.startswith("blocks.")), + default=0, + ) + 1 + _zstd_levels = [int(os.environ.get("ZSTD_LEVEL", "16")), 1, 17, 2] + # Phase 1: pure int6 with multiple zstd levels + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = None + chosen_level = _zstd_levels[0] + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int6 zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + # Phase 2: progressive int5 fallback — one layer at a time from middle outward + if quant_blob is None: + mid = num_layers_total // 2 + # Expand outward from center: L5, L4, L6, L3, L7, L2, L8, ... 
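+ # (With num_layers_total=11 the loop below tries the +offset side first, giving the
+ #  order L5, L6, L4, L7, L3, L8, L2, L9, L1, L10, L0, consistent with the
+ #  quant_fallback log entries that downgrade layer 5 first, then layers 5 and 6.)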
+ candidates = [] + for offset in range(num_layers_total): + for sign in [0, 1]: + layer = mid + offset if sign == 0 else mid - offset + if 0 <= layer < num_layers_total and layer not in candidates: + candidates.append(layer) + int5_layers: set[int] = set() + for layer in candidates: + int5_layers.add(layer) + if master_process: + log0(f"quant_fallback: int5 layers={sorted(int5_layers)}") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, int5_layers=int5_layers) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int5[{len(int5_layers)}L] zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + if quant_blob is not None: + break + if quant_blob is None: + quant_blob = blob # Use last attempt even if over limit + if master_process: + log0(f"WARNING: artifact still over limit after all fallbacks") + if master_process: + with open(_ptz_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + log0(f"Serialized model quant+{_COMPRESSOR}-{chosen_level}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open(_ptz_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + legal_ttt = bool(int(os.environ.get("LEGAL_TTT", "0"))) + if args.ttt_enabled and not legal_ttt: + # --- 
Invalid two-pass TTT (adapt then eval separately) --- + if distributed: + dist.barrier() + for block in eval_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + log0(f"ttt:start score-first optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"chunk_tokens={args.ttt_chunk_tokens}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, + rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + if legal_ttt and args.ttt_enabled: + # Legal single-pass TTT: score → train interleaved per chunk + log0(f"legal_ttt:start stride={args.eval_stride} " + f"optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + sw_val_loss, sw_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + log_fn=log0, + ) + else: + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + + +==================================================================================================== +Running Python 3.12.7 | packaged by 
conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] +Running PyTorch 2.10.0+cu128 +Sun Mar 29 20:03:41 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.172.08 Driver Version: 570.172.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA A100-PCIE-40GB On | 00000000:65:00.0 Off | 0 | +| N/A 34C P0 46W / 250W | 423MiB / 40960MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 3356590 C ...ameter_golf/.venv/bin/python3 414MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +eval_only: loading /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_55342820/models/final_model_pr940_nflow_55342820.pt, skipping training +step:0/0 val_loss:1.9197 val_bpb:1.1370 train_time:67ms step_avg:67.20ms +peak memory allocated: 25725 MiB reserved: 25972 MiB +save_paths: pt=final_model_eval_nflow7k_nottt_55375246.pt ptz=final_model_eval_nflow7k_nottt_55375246.int6.ptz +Serialized model: 107295209 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16242245 bytes (limit 15884968) +quant_try int6 zstd-1: 16298729 bytes (limit 15884968) +quant_try int6 zstd-17: 16250339 bytes (limit 15884968) +quant_try int6 zstd-2: 16301974 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15927696 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16001920 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16297960 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16053283 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15630744 
bytes (limit 15884968) +Serialized model quant+zstd-16: 15630744 bytes +Total submission size: 15745776 bytes +final_int6_roundtrip val_loss:1.9363 val_bpb:1.1468 eval_time:67735ms +final_int6_roundtrip_exact val_loss:1.93630739 val_bpb:1.14679030 +final_int6_sliding_window val_loss:1.8963 val_bpb:1.1231 stride:64 eval_time:1702654ms +final_int6_sliding_window_exact val_loss:1.89632895 val_bpb:1.12311579 diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s1337_legal_ttt.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s1337_legal_ttt.sh new file mode 100644 index 0000000000..28db1df401 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s1337_legal_ttt.sh @@ -0,0 +1,103 @@ +#!/bin/bash +############################################################################# +# Eval: NativeFlowMatcher 7k with LEGAL TTT — SEED 1337 (fixed paths) +############################################################################# +#SBATCH --job-name=eval_nf_s1337_lttt +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=05:00:00 +#SBATCH --account=medcam +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s1337_55398556/eval/eval_legal_ttt_%j.out +#SBATCH --error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s1337_55398556/eval/eval_legal_ttt_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf + +set -euo pipefail + +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate +export CUDA_VISIBLE_DEVICES=0 + +export RUN_ID="eval_nflow_s1337_legal_ttt_${SLURM_JOB_ID}" +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s1337_55398556/eval/eval_legal_ttt_${SLURM_JOB_ID}.txt" + +# ── Hardcoded checkpoint path ─────────────────────────────────────────── +export EVAL_ONLY="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s1337_55398556/models/final_model_pr940_nflow_s1337_55398556.pt" + +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + +# ── Architecture ───────────────────────────────────────────────────────── +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# ── NativeFlowMatcher ─────────────────────────────────────────────────── +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# ── Disabled modules ──────────────────────────────────────────────────── +export FLOW_ENABLED=0 +export E2E_TTT_ENABLED=0 +export EMA_ENABLED=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 +export CANON_LAST_N=0 + +# ── Legal TTT config ──────────────────────────────────────────────────── +export TTT_ENABLED=1 +export LEGAL_TTT=1 +export TTT_OPTIMIZER=sgd +export TTT_LR=0.002 +export TTT_EPOCHS=10 +export 
TTT_FREEZE_BLOCKS=2 +export TTT_BATCH_SEQS=32 +export TTT_CHUNK_TOKENS=32768 +export TTT_GRAD_CLIP=1.0 +export TTT_MOMENTUM=0.9 + +# ── Eval config ────────────────────────────────────────────────────────── +export EVAL_STRIDE=64 +export SEED=1337 + +# ── Training params (unused but required) ──────────────────────────────── +export ITERATIONS=7000 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 + +echo "=== Eval NativeFlow seed=1337 with Legal TTT (fixed paths) ===" +echo "Checkpoint: ${EVAL_ONLY}" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Start: $(date)" + +torchrun --standalone --nproc_per_node=1 train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "End: $(date)" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s1337_nottt.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s1337_nottt.sh new file mode 100644 index 0000000000..9437859552 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s1337_nottt.sh @@ -0,0 +1,92 @@ +#!/bin/bash +############################################################################# +# Eval: NativeFlowMatcher 7k NO TTT — SEED 1337 (fixed paths) +############################################################################# +#SBATCH --job-name=eval_nf_s1337_nottt +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=01:00:00 +#SBATCH --account=medcam +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s1337_55398556/eval/eval_nottt_%j.out +#SBATCH --error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s1337_55398556/eval/eval_nottt_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf + +set -euo pipefail + +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate +export CUDA_VISIBLE_DEVICES=0 + +export RUN_ID="eval_nflow_s1337_nottt_${SLURM_JOB_ID}" +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s1337_55398556/eval/eval_nottt_${SLURM_JOB_ID}.txt" + +# ── Hardcoded checkpoint path ─────────────────────────────────────────── +export EVAL_ONLY="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s1337_55398556/models/final_model_pr940_nflow_s1337_55398556.pt" + +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + +# ── Architecture ───────────────────────────────────────────────────────── +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# ── NativeFlowMatcher ─────────────────────────────────────────────────── +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export 
NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# ── Disabled modules ──────────────────────────────────────────────────── +export FLOW_ENABLED=0 +export E2E_TTT_ENABLED=0 +export EMA_ENABLED=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 +export CANON_LAST_N=0 +export TTT_ENABLED=0 + +# ── Eval config ────────────────────────────────────────────────────────── +export EVAL_STRIDE=64 +export SEED=1337 + +# ── Training params (unused but required) ──────────────────────────────── +export ITERATIONS=7000 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 + +echo "=== Eval NativeFlow seed=1337 – No TTT (fixed paths) ===" +echo "Checkpoint: ${EVAL_ONLY}" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Start: $(date)" + +torchrun --standalone --nproc_per_node=1 train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "End: $(date)" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s2025_legal_ttt.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s2025_legal_ttt.sh new file mode 100644 index 0000000000..d1722d3e7e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s2025_legal_ttt.sh @@ -0,0 +1,103 @@ +#!/bin/bash +############################################################################# +# Eval: NativeFlowMatcher 7k with LEGAL TTT — SEED 2025 (fixed paths) +############################################################################# +#SBATCH --job-name=eval_nf_s2025_lttt +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=05:00:00 +#SBATCH --account=medcam +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s2025_55398557/eval/eval_legal_ttt_%j.out +#SBATCH --error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s2025_55398557/eval/eval_legal_ttt_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf + +set -euo pipefail + +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate +export CUDA_VISIBLE_DEVICES=0 + +export RUN_ID="eval_nflow_s2025_legal_ttt_${SLURM_JOB_ID}" +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s2025_55398557/eval/eval_legal_ttt_${SLURM_JOB_ID}.txt" + +# ── Hardcoded checkpoint path ─────────────────────────────────────────── +export EVAL_ONLY="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s2025_55398557/models/final_model_pr940_nflow_s2025_55398557.pt" + +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + +# ── Architecture ───────────────────────────────────────────────────────── +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export 
LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# ── NativeFlowMatcher ─────────────────────────────────────────────────── +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# ── Disabled modules ──────────────────────────────────────────────────── +export FLOW_ENABLED=0 +export E2E_TTT_ENABLED=0 +export EMA_ENABLED=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 +export CANON_LAST_N=0 + +# ── Legal TTT config ──────────────────────────────────────────────────── +export TTT_ENABLED=1 +export LEGAL_TTT=1 +export TTT_OPTIMIZER=sgd +export TTT_LR=0.002 +export TTT_EPOCHS=10 +export TTT_FREEZE_BLOCKS=2 +export TTT_BATCH_SEQS=32 +export TTT_CHUNK_TOKENS=32768 +export TTT_GRAD_CLIP=1.0 +export TTT_MOMENTUM=0.9 + +# ── Eval config ────────────────────────────────────────────────────────── +export EVAL_STRIDE=64 +export SEED=2025 + +# ── Training params (unused but required) ──────────────────────────────── +export ITERATIONS=7000 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 + +echo "=== Eval NativeFlow seed=2025 with Legal TTT (fixed paths) ===" +echo "Checkpoint: ${EVAL_ONLY}" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Start: $(date)" + +torchrun --standalone --nproc_per_node=1 train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "End: $(date)" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s2025_nottt.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s2025_nottt.sh new file mode 100644 index 0000000000..1f2f9b2feb --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_eval_s2025_nottt.sh @@ -0,0 +1,92 @@ +#!/bin/bash +############################################################################# +# Eval: NativeFlowMatcher 7k NO TTT — SEED 2025 (fixed paths) +############################################################################# +#SBATCH --job-name=eval_nf_s2025_nottt +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=01:00:00 +#SBATCH --account=medcam +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s2025_55398557/eval/eval_nottt_%j.out +#SBATCH --error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s2025_55398557/eval/eval_nottt_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf + +set -euo pipefail + +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate +export CUDA_VISIBLE_DEVICES=0 + +export RUN_ID="eval_nflow_s2025_nottt_${SLURM_JOB_ID}" +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s2025_55398557/eval/eval_nottt_${SLURM_JOB_ID}.txt" + +# ── Hardcoded checkpoint path ─────────────────────────────────────────── +export EVAL_ONLY="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_s2025_55398557/models/final_model_pr940_nflow_s2025_55398557.pt" + +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + 
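+# NOTE: the architecture and NativeFlowMatcher exports below mirror the training
+# script for this seed (slurm_nflow_train_s2025.sh) so the checkpoint loads with
+# matching parameter shapes.
+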
+# ── Architecture ───────────────────────────────────────────────────────── +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# ── NativeFlowMatcher ─────────────────────────────────────────────────── +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# ── Disabled modules ──────────────────────────────────────────────────── +export FLOW_ENABLED=0 +export E2E_TTT_ENABLED=0 +export EMA_ENABLED=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 +export CANON_LAST_N=0 +export TTT_ENABLED=0 + +# ── Eval config ────────────────────────────────────────────────────────── +export EVAL_STRIDE=64 +export SEED=2025 + +# ── Training params (unused but required) ──────────────────────────────── +export ITERATIONS=7000 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 + +echo "=== Eval NativeFlow seed=2025 – No TTT (fixed paths) ===" +echo "Checkpoint: ${EVAL_ONLY}" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Start: $(date)" + +torchrun --standalone --nproc_per_node=1 train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "End: $(date)" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_nflow_train_s1337.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_nflow_train_s1337.sh new file mode 100644 index 0000000000..abf7f465fd --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_nflow_train_s1337.sh @@ -0,0 +1,101 @@ +#!/bin/bash +############################################################################# +# Training: NativeFlowMatcher 7k — SEED 1337 (reproducibility run 2/3) +# Matches configuration of job 55342820 (seed=42) exactly +############################################################################# +#SBATCH --job-name=nflow_s1337 +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=14:00:00 +#SBATCH --nice=0 +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_s1337_%j.out +#SBATCH --error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_s1337_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf +#SBATCH --account=medcam + +echo "=== Job Info ===" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "GPUs: $CUDA_VISIBLE_DEVICES" +echo "Start: $(date)" +echo "================" + +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate +export CUDA_VISIBLE_DEVICES=0 + +# --- Run & Data --- +export RUN_ID="pr940_nflow_s1337_${SLURM_JOB_ID}" +export SEED=1337 +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export 
TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + +# --- Training Schedule --- +export MAX_WALLCLOCK_SECONDS=0 +export ITERATIONS=7000 +export VAL_LOSS_EVERY=500 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 + +# --- Architecture (identical to seed 42 run) --- +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# --- Optimizer --- +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 +export EVAL_STRIDE=64 + +# --- EMA --- +export EMA_ENABLED=1 +export EMA_DECAY=0.997 + +# --- Disabled features --- +export TTT_ENABLED=0 +export CANON_LAST_N=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 +export FLOW_ENABLED=0 + +# --- NativeFlowMatcher --- +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# --- Log file --- +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_s1337_${SLURM_JOB_ID}.txt" + +echo "Running NFM training seed=1337 RUN_ID=$RUN_ID" +echo "Log: $LOGFILE" + +torchrun --standalone --nproc_per_node=1 \ + train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "=== Done: $(date) ===" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_nflow_train_s2025.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_nflow_train_s2025.sh new file mode 100644 index 0000000000..56611986e8 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/slurm_nflow_train_s2025.sh @@ -0,0 +1,101 @@ +#!/bin/bash +############################################################################# +# Training: NativeFlowMatcher 7k — SEED 2025 (reproducibility run 3/3) +# Matches configuration of job 55342820 (seed=42) exactly +############################################################################# +#SBATCH --job-name=nflow_s2025 +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=14:00:00 +#SBATCH --nice=0 +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_s2025_%j.out +#SBATCH --error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_s2025_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf +#SBATCH --account=medcam + +echo "=== Job Info ===" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "GPUs: $CUDA_VISIBLE_DEVICES" +echo "Start: $(date)" +echo "================" + +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate +export CUDA_VISIBLE_DEVICES=0 + +# --- Run & Data --- +export RUN_ID="pr940_nflow_s2025_${SLURM_JOB_ID}" +export SEED=2025 +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export 
TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + +# --- Training Schedule --- +export MAX_WALLCLOCK_SECONDS=0 +export ITERATIONS=7000 +export VAL_LOSS_EVERY=500 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 + +# --- Architecture (identical to seed 42 run) --- +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# --- Optimizer --- +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 +export EVAL_STRIDE=64 + +# --- EMA --- +export EMA_ENABLED=1 +export EMA_DECAY=0.997 + +# --- Disabled features --- +export TTT_ENABLED=0 +export CANON_LAST_N=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 +export FLOW_ENABLED=0 + +# --- NativeFlowMatcher --- +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# --- Log file --- +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_s2025_${SLURM_JOB_ID}.txt" + +echo "Running NFM training seed=2025 RUN_ID=$RUN_ID" +echo "Log: $LOGFILE" + +torchrun --standalone --nproc_per_node=1 \ + train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "=== Done: $(date) ===" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/train_s1337.log b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/train_s1337.log new file mode 100644 index 0000000000..6007a40fec --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/train_s1337.log @@ -0,0 +1,2706 @@ +logs/pr940_nflow_s1337_55398556.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/7000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.01ms +step:1/7000 train_loss:7.1301 train_time:1870ms step_avg:1870.23ms +step:2/7000 train_loss:9.1323 
train_time:3808ms step_avg:1904.02ms +step:3/7000 train_loss:8.0988 train_time:5745ms step_avg:1915.01ms +step:4/7000 train_loss:7.4310 train_time:7684ms step_avg:1921.02ms +step:5/7000 train_loss:7.2609 train_time:9624ms step_avg:1924.71ms +step:6/7000 train_loss:7.2483 train_time:11565ms step_avg:1927.55ms +step:7/7000 train_loss:7.0422 train_time:13508ms step_avg:1929.66ms +step:8/7000 train_loss:6.8783 train_time:15449ms step_avg:1931.17ms +step:9/7000 train_loss:6.5860 train_time:17393ms step_avg:1932.57ms +step:10/7000 train_loss:6.2036 train_time:19340ms step_avg:1934.01ms +step:200/7000 train_loss:2.7335 train_time:393022ms step_avg:1965.11ms +step:400/7000 train_loss:2.5355 train_time:786228ms step_avg:1965.57ms +step:500/7000 val_loss:2.3261 val_bpb:1.3777 train_time:982674ms step_avg:1965.35ms +step:600/7000 train_loss:2.3711 train_time:1179199ms step_avg:1965.33ms +step:800/7000 train_loss:2.3689 train_time:1572034ms step_avg:1965.04ms +step:1000/7000 train_loss:2.3500 train_time:1965102ms step_avg:1965.10ms +step:1000/7000 val_loss:2.2029 val_bpb:1.3047 train_time:1965105ms step_avg:1965.11ms +step:1200/7000 train_loss:2.3094 train_time:2358246ms step_avg:1965.20ms +step:1400/7000 train_loss:2.3491 train_time:2751529ms step_avg:1965.38ms +step:1500/7000 val_loss:2.1612 val_bpb:1.2800 train_time:2948216ms step_avg:1965.48ms +step:1600/7000 train_loss:2.1923 train_time:3144940ms step_avg:1965.59ms +step:1800/7000 train_loss:2.2428 train_time:3538724ms step_avg:1965.96ms +step:2000/7000 train_loss:2.1343 train_time:3932516ms step_avg:1966.26ms +step:2000/7000 val_loss:2.1101 val_bpb:1.2497 train_time:3932520ms step_avg:1966.26ms +step:2200/7000 train_loss:2.2132 train_time:4326232ms step_avg:1966.47ms +step:2400/7000 train_loss:2.1632 train_time:4720234ms step_avg:1966.76ms +step:2500/7000 val_loss:2.0874 val_bpb:1.2363 train_time:4917159ms step_avg:1966.86ms +step:2600/7000 train_loss:2.1639 train_time:5114189ms step_avg:1967.00ms +step:2800/7000 train_loss:2.2085 train_time:5508072ms step_avg:1967.17ms +step:3000/7000 train_loss:2.1608 train_time:5901941ms step_avg:1967.31ms +step:3000/7000 val_loss:2.0742 val_bpb:1.2285 train_time:5901944ms step_avg:1967.31ms +step:3200/7000 train_loss:2.1644 train_time:6295959ms step_avg:1967.49ms +step:3400/7000 train_loss:2.1390 train_time:6690073ms step_avg:1967.67ms +step:3500/7000 val_loss:2.0663 val_bpb:1.2238 train_time:6887089ms step_avg:1967.74ms +step:3600/7000 train_loss:2.1812 train_time:7084186ms step_avg:1967.83ms +step:3800/7000 train_loss:2.1519 train_time:7478201ms step_avg:1967.95ms +step:4000/7000 train_loss:2.2331 train_time:7872398ms step_avg:1968.10ms +step:4000/7000 val_loss:2.0608 val_bpb:1.2205 train_time:7872401ms step_avg:1968.10ms +step:4200/7000 train_loss:2.1234 train_time:8266454ms step_avg:1968.20ms +step:4400/7000 train_loss:2.1051 train_time:8660328ms step_avg:1968.26ms +step:4500/7000 val_loss:2.0439 val_bpb:1.2105 train_time:8857296ms step_avg:1968.29ms +step:4600/7000 train_loss:2.0881 train_time:9054351ms step_avg:1968.34ms +step:4800/7000 train_loss:2.2217 train_time:9448372ms step_avg:1968.41ms +step:5000/7000 train_loss:2.1346 train_time:9842356ms step_avg:1968.47ms +step:5000/7000 val_loss:2.0234 val_bpb:1.1983 train_time:9842359ms step_avg:1968.47ms +step:5200/7000 train_loss:2.1035 train_time:10236364ms step_avg:1968.53ms +step:5400/7000 train_loss:2.0777 train_time:10630253ms step_avg:1968.57ms +step:5500/7000 val_loss:2.0005 val_bpb:1.1848 train_time:10827248ms step_avg:1968.59ms +step:5600/7000 
train_loss:2.0621 train_time:11024162ms step_avg:1968.60ms +step:5800/7000 train_loss:2.0492 train_time:11418063ms step_avg:1968.63ms +step:6000/7000 train_loss:2.0302 train_time:11812016ms step_avg:1968.67ms +step:6000/7000 val_loss:1.9772 val_bpb:1.1710 train_time:11812019ms step_avg:1968.67ms +step:6200/7000 train_loss:2.1248 train_time:12205822ms step_avg:1968.68ms +step:6400/7000 train_loss:2.1042 train_time:12599771ms step_avg:1968.71ms +step:6500/7000 val_loss:1.9471 val_bpb:1.1532 train_time:12796711ms step_avg:1968.72ms +step:6600/7000 train_loss:2.0032 train_time:12993629ms step_avg:1968.73ms +step:6800/7000 train_loss:2.0856 train_time:13387398ms step_avg:1968.74ms +step:7000/7000 train_loss:1.9287 train_time:13780882ms step_avg:1968.70ms +step:7000/7000 val_loss:1.9223 val_bpb:1.1385 train_time:13780886ms step_avg:1968.70ms +peak memory allocated: 25832 MiB reserved: 26006 MiB +ema:applying EMA weights +save_paths: pt=final_model_pr940_nflow_s1337_55398556.pt ptz=final_model_pr940_nflow_s1337_55398556.int6.ptz +Serialized model: 107294152 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16265007 bytes (limit 15884968) +quant_try int6 zstd-1: 16289666 bytes (limit 15884968) +quant_try int6 zstd-17: 16250866 bytes (limit 15884968) +quant_try int6 zstd-2: 16298408 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15944126 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 15995452 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16042936 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16051685 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15621901 bytes (limit 15884968) +Serialized model quant+zstd-16: 15621901 bytes +Total submission size: 15736933 bytes +final_int6_roundtrip val_loss:1.9372 val_bpb:1.1473 eval_time:67819ms +final_int6_roundtrip_exact val_loss:1.93715323 val_bpb:1.14729126 +final_int6_sliding_window val_loss:1.8973 val_bpb:1.1237 stride:64 eval_time:1709300ms +final_int6_sliding_window_exact val_loss:1.89726464 val_bpb:1.12366996 +ss_weight = float(os.environ.get("NATIVE_FLOW_LOSS_WEIGHT", "0.1")) + + # E2E TTT-Linear refiner (Sun et al., 2024) + e2e_ttt_enabled = bool(int(os.environ.get("E2E_TTT_ENABLED", "0"))) + e2e_ttt_num_heads = int(os.environ.get("E2E_TTT_NUM_HEADS", "8")) + e2e_ttt_mini_batch = int(os.environ.get("E2E_TTT_MINI_BATCH", "16")) + e2e_ttt_base_lr = float(os.environ.get("E2E_TTT_BASE_LR", "1.0")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = 
dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
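+    # We keep floor((N - 1) / seq_len) whole windows of seq_len tokens plus one trailing token,
+    # so the eval loop can form shifted (x, y) pairs that exactly tile the retained span.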
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
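+#
+# In this record the exported artifact is zstd-compressed (level 16) rather than zlib, and selected
+# MLP weight matrices are stored at int6/int5 via INT6_MLP_LAYERS and the int5 fallback (see the
+# quant_try / quant_fallback lines in the eval and training logs above).
+#
+# Sketch of the per-row scheme applied to 2D float tensors (mirroring quantize_float_tensor below):
+#   clip_abs[i] = quantile(|W[i, :]|, INT8_CLIP_Q)           # per-row clipping threshold
+#   scale[i]    = max(clip_abs[i] / qmax, 1 / qmax)          # qmax = 127 / 31 / 15 for int8 / int6 / int5
+#   Q[i, :]     = clamp(round(clip(W[i, :]) / scale[i]), -qmax, qmax)
+#   W_hat[i, :] = Q[i, :] * scale[i]                         # dequantized on load (per-row fp16 scale)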
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,vr_lambda,attn_gate,canon_a,canon_c,delta_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, qmax: int = 127) -> tuple[Tensor, Tensor]: + """Quantize to [-qmax, qmax] range. Default int8 (qmax=127), int6 (qmax=31), int5 (qmax=15).""" + t32 = t.float() + qmin = -qmax + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
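+        # "Small" means numel <= INT8_KEEP_FLOAT_MAX_NUMEL (65,536). Control tensors matching
+        # CONTROL_TENSOR_NAME_PATTERNS stay fp32 via keep_float_tensor; other small fp32/bf16
+        # tensors are stored as fp16, with the original dtype recorded in passthrough_orig_dtypes
+        # so dequantize_state_dict_int8 can cast them back to their original dtype on load.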
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 for MLP layers 3-7 to save artifact space + int6_mlp_layers = os.environ.get("INT6_MLP_LAYERS", "") + qmax = 127 # default int8 + if int6_mlp_layers: + for li in int6_mlp_layers.split(","): + if li.strip() and f"blocks.{li.strip()}.mlp" in name and t.ndim == 2: + qmax = 31 # int6 + break + q, s = quantize_float_tensor(t, qmax=qmax) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
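+    # Worked example with this record's config (TRAIN_BATCH_TOKENS=786432, TRAIN_SEQ_LEN=2048,
+    # world_size=1, grad_accum_steps=8): each micro-step consumes 786432 / (1 * 8) = 98,304 tokens
+    # per rank plus one, reshaped into x and y of shape (48, 2048).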
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round: bool = False + _soft_round_alpha: float = 1.0 + _quant_percentile: float = float(os.environ.get("QUANT_PERCENTILE", "1.0")) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + w32 = self.weight.float() + pct = CastedLinear._quant_percentile + row_max = (torch.quantile(w32.abs(), pct, dim=1) if pct < 1.0 + else w32.abs().amax(dim=1)).detach() + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + r = w32 / scale[:, None] + if CastedLinear._soft_round: + alpha = CastedLinear._soft_round_alpha + r_frac = r - r.detach().floor() - 0.5 + norm = torch.tanh(torch.tensor(alpha * 0.5, device=r.device, dtype=r.dtype)) + r_soft = r.detach().floor() + 0.5 + torch.tanh(alpha * r_frac) / (2.0 * norm) + w_q = (torch.clamp(r_soft, -32, 31) * scale[:, None]).to(x.dtype) + w = w_q # soft-round is differentiable, no STE needed + else: + with torch.no_grad(): + w_q = (torch.clamp(torch.round(r), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. 
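+    # For seq_len > train_seq_len the base is rescaled as
+    #   new_base = base * (seq_len / train_seq_len) ** (rope_dims / (rope_dims - 2)),
+    # e.g. rope_dims=16 with a 2x longer sequence gives new_base ≈ 2.21 * base (10000 → ≈ 22,100).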
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + value_residual: bool = False, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = 
v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + fa_dtype = torch.bfloat16 + if _USE_FA3: + y = flash_attn_3_func(q.to(fa_dtype), k.to(fa_dtype), v.to(fa_dtype), causal=True) + else: + # SDPA fallback: (B, T, H, D) -> (B, H, T, D), expand KV for GQA + q_t = q.to(fa_dtype).transpose(1, 2) + k_t = k.to(fa_dtype).transpose(1, 2) + v_t = v.to(fa_dtype).transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(rep, dim=1) + v_t = v_t.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True) + y = y.transpose(1, 2) # (B, H, T, D) -> (B, T, H, D) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)) # (B, T, num_heads) + y = y * gate.unsqueeze(-1) # (B, T, H, 1) broadcast to (B, T, H, D) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.use_leaky = bool(int(os.environ.get("LEAKY_RELU", "1"))) + self.leaky_slope = 
float(os.environ.get("LEAKY_SLOPE", "0.5")) + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), self.leaky_slope) if self.use_leaky else torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class CanonAC(nn.Module): + """Canon Autoregressive Convolution with DeltaGate. Manual shift+mul (no Conv1d).""" + def __init__(self, dim: int, kernel: int = 4, delta_gate_init: float = -4.0): + super().__init__() + self.kernel = kernel + self.weight = nn.Parameter(torch.zeros(kernel, dim)) + self.delta_gate_logit = nn.Parameter(torch.tensor(delta_gate_init)) + + def forward(self, x: Tensor) -> Tensor: + B, T, D = x.shape + K = self.kernel + w = self.weight.to(x.dtype) + x_pad = F.pad(x, (0, 0, K - 1, 0)) + y = w[0] * x_pad[:, K - 1:, :] + for k in range(1, K): + y = y + w[k] * x_pad[:, K - 1 - k : T + K - 1 - k, :] + gate = torch.sigmoid(self.delta_gate_logit.to(x.dtype)) + return x + gate * y + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_kernel: int = 0, + canon_delta_gate_init: float = -4.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, value_residual=value_residual, + gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.canon_a = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + self.canon_c = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_in = self.attn_norm(x) * s + if self.canon_a is not None: + attn_in = self.canon_a(attn_in) + attn_out, raw_v = self.attn(attn_in, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s + if self.canon_c is not None: + mlp_in = self.canon_c(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + return x, raw_v + + +class FlowRefiner(nn.Module): + """1-step flow matching refiner in low-dim latent space. + + Projects hidden states to a low-dim latent, applies a learned velocity + field (1-step Euler), projects back. Output is additive adjustment to + hidden states. Initialized near-zero so the refiner starts as identity. 
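+
+    Note: inactive in the seed runs recorded here (the SLURM scripts export FLOW_ENABLED=0);
+    the refiner enabled for this submission is NativeFlowMatcher below.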
+ """ + def __init__(self, model_dim: int, latent_dim: int = 64, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.down_proj = nn.Linear(model_dim, latent_dim, bias=False) + self.velocity_net = nn.Sequential( + nn.Linear(latent_dim, hidden_dim, bias=True), + nn.GELU(), + nn.Linear(hidden_dim, latent_dim, bias=True), + ) + self.up_proj = nn.Linear(latent_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.velocity_net[2]._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.down_proj.weight, std=self._init_scale) + nn.init.normal_(self.velocity_net[0].weight, std=self._init_scale) + nn.init.zeros_(self.velocity_net[0].bias) + nn.init.zeros_(self.velocity_net[2].weight) + nn.init.zeros_(self.velocity_net[2].bias) + nn.init.normal_(self.up_proj.weight, std=self._init_scale) + + def forward(self, x: Tensor) -> Tensor: + z = self.down_proj(x) + v = self.velocity_net(z) + z_refined = z + v + delta = self.up_proj(z_refined) + return torch.sigmoid(self.gate) * delta + + +class NativeFlowMatcher(nn.Module): + """Conditional Flow Matching refiner for hidden states. + + During training: computes auxiliary CFM loss on interpolated states. + During inference: applies a single Euler step correction at t=1. + """ + def __init__(self, model_dim: int, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.model_dim = model_dim + # Sinusoidal time embedding → projected to hidden_dim + self.time_proj = nn.Sequential( + nn.Linear(model_dim, hidden_dim, bias=True), + nn.GELU(), + ) + # Velocity network: x → hidden → x (with time conditioning via addition) + self.v_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.v_act = nn.GELU() + self.v_out = nn.Linear(hidden_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.v_out._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.v_in.weight, std=self._init_scale) + nn.init.zeros_(self.v_in.bias) + nn.init.normal_(self.time_proj[0].weight, std=self._init_scale) + nn.init.zeros_(self.time_proj[0].bias) + nn.init.zeros_(self.v_out.weight) + + @staticmethod + def _sinusoidal_time_emb(t: Tensor, dim: int) -> Tensor: + """Sinusoidal positional embedding for scalar time t. t: (...,) -> (..., dim).""" + half_dim = dim // 2 + emb = math.log(10000.0) / max(half_dim - 1, 1) + emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb) + emb = t.unsqueeze(-1) * emb # (..., half_dim) + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) # (..., dim) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + def _velocity(self, x: Tensor, t_emb: Tensor) -> Tensor: + """Compute velocity v(x, t). x: (..., model_dim), t_emb: (..., hidden_dim).""" + h = self.v_in(x) # (..., hidden_dim) + h = h + t_emb # time conditioning via addition + h = self.v_act(h) # GELU + return self.v_out(h) # (..., model_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (correction, cfm_loss). + + During training: cfm_loss is the flow matching objective (> 0). + During eval: cfm_loss is zero. + correction: gated velocity at t=1, to be added to x. 
+ """ + B_seq = x.shape[:-1] # works for any leading dimensions + D = self.model_dim + + # Inference path: velocity at t=1 (clean input, no noise) + t_ones = torch.ones(*B_seq, device=x.device, dtype=x.dtype) + t_emb_ones = self.time_proj(self._sinusoidal_time_emb(t_ones, D)) + v_at_1 = self._velocity(x, t_emb_ones) + correction = torch.sigmoid(self.gate) * v_at_1 + + # Training path: auxiliary CFM loss + if self.training: + # Sample random t ~ U(0,1) per position + t = torch.rand(*B_seq, device=x.device, dtype=x.dtype) + z = torch.randn_like(x) # noise + # OT interpolant: x_t = (1-t)*z + t*x + t_expanded = t.unsqueeze(-1) # (..., 1) + x_t = (1.0 - t_expanded) * z + t_expanded * x.detach() + # Target velocity = x - z (OT conditional velocity field) + target_v = x.detach() - z + # Predict velocity at interpolated point + t_emb = self.time_proj(self._sinusoidal_time_emb(t, D)) + pred_v = self._velocity(x_t, t_emb) + cfm_loss = F.mse_loss(pred_v, target_v) + else: + cfm_loss = x.new_zeros(()) + + return correction, cfm_loss + + +class TTTLinearRefiner(nn.Module): + """E2E TTT-Linear refiner (Sun et al., 2024, arXiv:2407.04620). + + Applied to final hidden states [B, L, D] before lm_head. The hidden + state is a per-head linear model W updated via mini-batch gradient + descent on a learned self-supervised reconstruction task. Uses the + dual form for GPU efficiency. Trained end-to-end: the outer loop + optimizes theta_K, theta_V, theta_Q projections and the inner-loop + initialization W_0, making the model learn WHAT to compress. + """ + + def __init__(self, model_dim: int, num_heads: int = 8, + mini_batch_size: int = 16, ttt_base_lr: float = 1.0): + super().__init__() + self.model_dim = model_dim + self.num_heads = num_heads + self.head_dim = model_dim // num_heads + self.mini_batch_size = mini_batch_size + self.ttt_base_lr = ttt_base_lr + self.eta = ttt_base_lr / self.head_dim + + # Learnable projections (outer-loop) + self.theta_K = nn.Linear(model_dim, model_dim, bias=False) + self.theta_V = nn.Linear(model_dim, model_dim, bias=False) + self.theta_Q = nn.Linear(model_dim, model_dim, bias=False) + + # Inner-loop model: W_0, b_0 + self.W1 = nn.Parameter(torch.normal(0, 0.02, + size=(num_heads, self.head_dim, self.head_dim))) + self.b1 = nn.Parameter(torch.zeros(num_heads, 1, self.head_dim)) + + # Inner-loop LayerNorm (outer-loop trainable) + self.ttt_ln_w = nn.Parameter(torch.ones(num_heads, self.head_dim)) + self.ttt_ln_b = nn.Parameter(torch.zeros(num_heads, self.head_dim)) + + # Output projection + gate (starts near zero for residual safety) + self.o_proj = nn.Linear(model_dim, model_dim, bias=False) + self.post_norm = nn.LayerNorm(model_dim, eps=1e-6) + self.gate = nn.Parameter(torch.tensor(-5.0)) + nn.init.zeros_(self.o_proj.weight) + self.o_proj._zero_init = True + + @staticmethod + def _ln_fwd(x: Tensor, gamma: Tensor, beta: Tensor, eps: float = 1e-6) -> Tensor: + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x_hat = (x - mu) / torch.sqrt(var + eps) + return gamma * x_hat + beta + + @staticmethod + def _ln_fused_l2_bwd(x: Tensor, target: Tensor, gamma: Tensor, beta: Tensor, + eps: float = 1e-6) -> Tensor: + D = x.shape[-1] + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + std = torch.sqrt(var + eps) + x_hat = (x - mu) / std + y = gamma * x_hat + beta + grad_output = y - target + grad_x_hat = grad_output * gamma + z = (1.0 / D) * ( + D * grad_x_hat + - grad_x_hat.sum(dim=-1, keepdim=True) + - x_hat * 
(grad_x_hat * x_hat).sum(dim=-1, keepdim=True) + ) / std + return z + + def forward(self, x: Tensor) -> Tensor: + """x: [B, L, D]. Returns [B, L, D] additive refinement (gated).""" + B, L, D = x.shape + NH, HD, b = self.num_heads, self.head_dim, self.mini_batch_size + eta = self.eta + + num_mb = L // b + usable_L = num_mb * b + x_proc = x[:, :usable_L] if usable_L < L else x + + # Project: [B, L, D] -> [B, NH, L, HD] + XK = self.theta_K(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XV = self.theta_V(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XQ = self.theta_Q(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + + # Reshape into mini-batches: [B, NH, num_mb, b, HD] + XK = XK.reshape(B, NH, num_mb, b, HD) + XV = XV.reshape(B, NH, num_mb, b, HD) + XQ = XQ.reshape(B, NH, num_mb, b, HD) + + ln_w = self.ttt_ln_w.reshape(1, NH, 1, HD) + ln_b = self.ttt_ln_b.reshape(1, NH, 1, HD) + + # Causal mask for dual form + causal = torch.tril(torch.ones(b, b, device=x.device, dtype=x.dtype)) + + # Initialize inner-loop model + W = self.W1.unsqueeze(0).expand(B, -1, -1, -1).clone() + bias = self.b1.unsqueeze(0).expand(B, -1, -1, -1).clone() + + all_out = [] + for m in range(num_mb): + xk = XK[:, :, m] + xv = XV[:, :, m] + xq = XQ[:, :, m] + + Z1 = xk @ W + bias + target = xv - xk + + grad = self._ln_fused_l2_bwd(Z1, target, ln_w, ln_b) + + Attn = causal.unsqueeze(0).unsqueeze(0) * (xq @ xk.transpose(-2, -1)) + cumgrad = (causal.unsqueeze(0).unsqueeze(0) @ grad) + b_bar = bias - eta * cumgrad + Z_bar = xq @ W - eta * Attn @ grad + b_bar + + W = W - eta * xk.transpose(-2, -1) @ grad + bias = bias - eta * grad.sum(dim=-2, keepdim=True) + + Z_bar = self._ln_fwd(Z_bar, ln_w, ln_b) + out_mb = xq + Z_bar + all_out.append(out_mb) + + z = torch.cat(all_out, dim=2).permute(0, 2, 1, 3).reshape(B, usable_L, D) + z = self.post_norm(z) + z = self.o_proj(z) + result = torch.sigmoid(self.gate) * z + + if usable_L < L: + pad = torch.zeros(B, L - usable_L, D, device=x.device, dtype=x.dtype) + result = torch.cat([result, pad], dim=1) + + return result + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_last_n: int = 0, + canon_kernel: int = 4, + canon_delta_gate_init: float = -4.0, + flow_enabled: bool = False, + flow_latent_dim: int = 64, + flow_hidden_dim: int = 256, + flow_init_scale: float = 0.01, + native_flow_enabled: bool = False, + native_flow_hidden_dim: int = 256, + native_flow_init_scale: float = 0.01, + native_flow_loss_weight: float = 0.1, + e2e_ttt_enabled: bool = False, + e2e_ttt_num_heads: int = 8, + e2e_ttt_mini_batch: int = 16, + e2e_ttt_base_lr: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, 
bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + canon_start = num_layers - canon_last_n if canon_last_n > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + value_residual=value_residual, + gated_attention=gated_attention, + canon_kernel=canon_kernel if i >= canon_start else 0, + canon_delta_gate_init=canon_delta_gate_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.flow_refiner = FlowRefiner(model_dim, flow_latent_dim, flow_hidden_dim, flow_init_scale) if flow_enabled else None + self.native_flow = NativeFlowMatcher(model_dim, native_flow_hidden_dim, native_flow_init_scale) if native_flow_enabled else None + self.native_flow_loss_weight = native_flow_loss_weight + self.ttt_refiner = TTTLinearRefiner(model_dim, e2e_ttt_num_heads, e2e_ttt_mini_batch, e2e_ttt_base_lr) if e2e_ttt_enabled else None + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + nflow_loss = x.new_zeros(()) + if self.native_flow is not None: + correction, nflow_loss = self.native_flow(x) + x = x + correction + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + if self.native_flow is not None and self.training: + main_loss = main_loss + self.native_flow_loss_weight * nflow_loss + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + if self.native_flow is not None: + correction, _ = self.native_flow(x) + x = x + correction + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + 
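+# Sliding-window scoring (as implemented below): windows of `seq_len` tokens
+# start every `stride` positions. The first window scores all of its
+# positions; each later window scores only its final `stride` positions, so
+# interior tokens are scored once each with the longest left context the
+# window size allows.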
+def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + Optionally uses entropy-gated 5-gram cache (NGRAM_CACHE=1).""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # N-gram eval cache with multi-order backoff + entropy-adaptive alpha (PR #702 inspired) + _ngram_default = "1" if world_size > 1 else "0" + use_ngram = bool(int(os.environ.get("NGRAM_CACHE", _ngram_default))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.40")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_order = int(os.environ.get("NGRAM_ORDER", "7")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = ngram_order - ngram_min_order + 1 + ctx_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + print(f"ngram_cache:enabled orders={ngram_min_order}-{ngram_order} backoff " + f"entropy={ngram_entropy} alpha={ngram_alpha} " + f"ent_base={ngram_ent_base} ent_range={ngram_ent_range} " + f"min_count={ngram_min_count} buckets={ngram_buckets}", flush=True) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + scored_nll = nll[i, s:wlen].to(torch.float64) + + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha: compute from model logits (GPU) + if ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ngram_ent_base + ngram_ent_range / ( + 1.0 + np.exp(-ngram_ent_scale * (seg_ent - ngram_ent_thresh))) + + # Precompute hashes for all orders + order_data = [] # (v_idx, ctx_key, full_key) per order + for oi in range(_n_orders): + ctx_w = ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, fill unmatched with lower orders + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if ngram_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll_np = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + scored_nll = torch.from_numpy(seg_nll_np).to(dtype=torch.float64, device=device) + + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + 
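+# Note: the second value returned above is bits-per-byte,
+#   bpb = (loss_sum / token_count) / ln(2) * (token_count / byte_count)
+#       = loss_sum / (ln(2) * byte_count),
+# where byte_count sums base_bytes_lut[target] over the scored tokens plus
+# one extra byte whenever the target carries a leading space and the previous
+# token is not a boundary token.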
+ +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _clear_rotary_caches(model: nn.Module) -> None: + """Clear cached RoPE tensors to avoid 'Inference tensors cannot be saved for backward'.""" + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + + +def ttt_adapt(args: Hyperparameters, base_model: nn.Module, device: torch.device, + val_tokens: Tensor, rank: int = 0, world_size: int = 1, + log_fn=None) -> None: + """Score-first TTT: process val data in chunks, score each chunk first + (inference_mode), then train on scored tokens. Compliant with Issue #677.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + batch_seqs = args.ttt_batch_seqs + + # Freeze early blocks + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + chunk_idx = 0 + + for chunk_start in range(0, total_tokens - seq_len, chunk_tokens): + chunk_end = min(chunk_start + chunk_tokens, total_tokens) + chunk_len = chunk_end - chunk_start + n_seqs = chunk_len // seq_len + if n_seqs == 0: + break + + my_start = (n_seqs * rank) // world_size + my_end = (n_seqs * (rank + 1)) // world_size + if my_end <= my_start: + continue + + # Phase 1: Score chunk under inference_mode (forward only) + base_model.eval() + with torch.inference_mode(): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + base_model.forward_logits(x) + + # Phase 2: Train on scored tokens (K epochs) + base_model.train() + for epoch in range(args.ttt_epochs): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + chunk_idx += 1 + if log_fn and chunk_idx % 20 == 0: + log_fn(f"ttt:chunk={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + # Restore all params + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done chunks={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, 
+ eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + """Legal single-pass TTT: score each chunk with sliding windows, then train on it. + Tokens are always scored BEFORE any training on their chunk, so the evaluation + is never contaminated by future information.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Build window starts (same logic as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Map each window to the chunk that contains its first scored token + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + if log_fn: + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + if log_fn: + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + _clear_rotary_caches(base_model) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR schedule with 5% warmup + warmup_chunks = max(num_chunks // 20, 1) + if ci < warmup_chunks: + lr_scale = (ci + 1) / warmup_chunks + else: + progress = (ci - warmup_chunks) / max(num_chunks - 1 - warmup_chunks, 1) + lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress)) + cos_lr = args.ttt_lr * lr_scale + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if log_fn and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all params and return to eval mode + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + if log_fn: + log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, qmax: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + qmin = -qmax - 1 + pct = CastedLinear._quant_percentile + if t32.ndim == 2: + row_max = (torch.quantile(t32.abs(), pct, dim=1) if pct < 1.0 + else t32.abs().amax(dim=1)) + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)).to(torch.float16) + clipped = t32.clamp(-row_max[:, None], row_max[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), qmin, qmax).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / float(qmax) if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), qmin, qmax).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_layers: set[int] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + if int5_layers is None: + int5_layers = set() + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Determine layer index for int5 fallback + layer_idx = -1 + if name.startswith("blocks."): + try: + layer_idx = int(name.split(".")[1]) + except (IndexError, ValueError): + pass + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + qmax = 15 if layer_idx in int5_layers else 31 + q, s = 
quantize_int6_per_row(t, qmax=qmax) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5" if qmax == 15 else "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + if _USE_FA3: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.flow_refiner is not None: + scalar_params.extend(list(base_model.flow_refiner.parameters())) + if base_model.native_flow is not None: + scalar_params.extend(list(base_model.native_flow.parameters())) + if base_model.ttt_refiner is not None: + scalar_params.extend(list(base_model.ttt_refiner.parameters())) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:fa3={_USE_FA3} cudnn=False flash=True mem_efficient={not _USE_FA3} math={not _USE_FA3}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 
1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + eval_only_path = os.environ.get("EVAL_ONLY", "") + if eval_only_path: + log0(f"eval_only: loading {eval_only_path}, skipping training") + base_model.load_state_dict(torch.load(eval_only_path, map_location=device, weights_only=False), strict=False) + ema_state = None # prevent random EMA from overwriting loaded weights + swa_state = None + swa_count = 0 + args.iterations = 0 # skip training, go straight to eval + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms 
step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if ( + args.checkpoint_every > 0 + and step > 0 + and step % args.checkpoint_every == 0 + and not last_step + and master_process + ): + ckpt_sd = {k: v for k, v in base_model.state_dict().items() if "mtp_heads" not in k} + ckpt_path = f"checkpoint_step{step}_{args.run_id}.pt" + torch.save(ckpt_sd, ckpt_path) + log0(f"checkpoint_saved: {ckpt_path} ({os.path.getsize(ckpt_path)} bytes)") + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._soft_round = args.soft_round_qat + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round:{args.soft_round_qat}") + if CastedLinear._qat_enabled and CastedLinear._soft_round: + qat_progress = max(0.0, 1.0 - (scale / qat_threshold)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress # 1→16 + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: + # Tight SWA: collect from EMA state if available, else from raw model + src = ema_state if ema_state is not None else {name: t.detach().float() for name, t in base_model.state_dict().items()} + if swa_state is None: + swa_state = {name: t.clone() for name, t in src.items()} + swa_count = 1 + log0(f"swa:start step:{step} tight={ema_state is not None}") + else: + for name in swa_state: + swa_state[name].add_(src[name]) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the 
wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying Tight SWA averaged {swa_count} EMA checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + if ema_state is not None: + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + # Use RUN_ID for unique filenames when multiple jobs share CWD + _run_id = os.environ.get("RUN_ID", "") + _save_prefix = f"final_model_{_run_id}" if _run_id else "final_model" + _pt_path = f"{_save_prefix}.pt" + _ptz_path = f"{_save_prefix}.int6.ptz" + log0(f"save_paths: pt={_pt_path} ptz={_ptz_path}") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, _pt_path) + model_bytes = os.path.getsize(_pt_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + code_bytes = len(code.encode("utf-8")) + artifact_limit = 16_000_000 - code_bytes + + # --- Auto-downgrade quantization: try int6 first, fall back to int5 middle layers --- + num_layers_total = max( + (int(k.split(".")[1]) for k in sd_cpu if k.startswith("blocks.")), + default=0, + ) + 1 + _zstd_levels = [int(os.environ.get("ZSTD_LEVEL", "16")), 1, 17, 2] + # Phase 1: pure int6 with multiple zstd levels + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = None + chosen_level = _zstd_levels[0] + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int6 zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + # Phase 2: progressive int5 fallback — one layer at a time from middle outward + if quant_blob is None: + mid = num_layers_total // 2 + # Expand outward from center: L5, L4, L6, L3, L7, L2, L8, ... 
+ candidates = [] + for offset in range(num_layers_total): + for sign in [0, 1]: + layer = mid + offset if sign == 0 else mid - offset + if 0 <= layer < num_layers_total and layer not in candidates: + candidates.append(layer) + int5_layers: set[int] = set() + for layer in candidates: + int5_layers.add(layer) + if master_process: + log0(f"quant_fallback: int5 layers={sorted(int5_layers)}") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, int5_layers=int5_layers) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int5[{len(int5_layers)}L] zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + if quant_blob is not None: + break + if quant_blob is None: + quant_blob = blob # Use last attempt even if over limit + if master_process: + log0(f"WARNING: artifact still over limit after all fallbacks") + if master_process: + with open(_ptz_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + log0(f"Serialized model quant+{_COMPRESSOR}-{chosen_level}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open(_ptz_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + legal_ttt = bool(int(os.environ.get("LEGAL_TTT", "0"))) + if args.ttt_enabled and not legal_ttt: + # --- 
Invalid two-pass TTT (adapt then eval separately) --- + if distributed: + dist.barrier() + for block in eval_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + log0(f"ttt:start score-first optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"chunk_tokens={args.ttt_chunk_tokens}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, + rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + if legal_ttt and args.ttt_enabled: + # Legal single-pass TTT: score → train interleaved per chunk + log0(f"legal_ttt:start stride={args.eval_stride} " + f"optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + sw_val_loss, sw_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + log_fn=log0, + ) + else: + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + + +==================================================================================================== +Running Python 3.12.7 | packaged by 
conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] +Running PyTorch 2.10.0+cu128 +Tue Mar 31 17:19:38 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.172.08 Driver Version: 570.172.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA A100-PCIE-40GB On | 00000000:17:00.0 Off | 0 | +| N/A 44C P0 49W / 250W | 423MiB / 40960MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 355166 C ...ameter_golf/.venv/bin/python3 414MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/7000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.01ms +step:1/7000 train_loss:7.1301 train_time:1870ms step_avg:1870.23ms +step:2/7000 train_loss:9.1323 train_time:3808ms step_avg:1904.02ms +step:3/7000 train_loss:8.0988 train_time:5745ms step_avg:1915.01ms +step:4/7000 train_loss:7.4310 train_time:7684ms step_avg:1921.02ms +step:5/7000 train_loss:7.2609 train_time:9624ms step_avg:1924.71ms +step:6/7000 train_loss:7.2483 train_time:11565ms step_avg:1927.55ms +step:7/7000 train_loss:7.0422 train_time:13508ms step_avg:1929.66ms +step:8/7000 train_loss:6.8783 train_time:15449ms step_avg:1931.17ms +step:9/7000 train_loss:6.5860 train_time:17393ms step_avg:1932.57ms +step:10/7000 train_loss:6.2036 train_time:19340ms step_avg:1934.01ms +step:200/7000 train_loss:2.7335 train_time:393022ms step_avg:1965.11ms +step:400/7000 train_loss:2.5355 train_time:786228ms step_avg:1965.57ms +step:500/7000 val_loss:2.3261 val_bpb:1.3777 train_time:982674ms step_avg:1965.35ms +step:600/7000 train_loss:2.3711 
train_time:1179199ms step_avg:1965.33ms +step:800/7000 train_loss:2.3689 train_time:1572034ms step_avg:1965.04ms +step:1000/7000 train_loss:2.3500 train_time:1965102ms step_avg:1965.10ms +step:1000/7000 val_loss:2.2029 val_bpb:1.3047 train_time:1965105ms step_avg:1965.11ms +step:1200/7000 train_loss:2.3094 train_time:2358246ms step_avg:1965.20ms +step:1400/7000 train_loss:2.3491 train_time:2751529ms step_avg:1965.38ms +step:1500/7000 val_loss:2.1612 val_bpb:1.2800 train_time:2948216ms step_avg:1965.48ms +step:1600/7000 train_loss:2.1923 train_time:3144940ms step_avg:1965.59ms +step:1800/7000 train_loss:2.2428 train_time:3538724ms step_avg:1965.96ms +step:2000/7000 train_loss:2.1343 train_time:3932516ms step_avg:1966.26ms +step:2000/7000 val_loss:2.1101 val_bpb:1.2497 train_time:3932520ms step_avg:1966.26ms +step:2200/7000 train_loss:2.2132 train_time:4326232ms step_avg:1966.47ms +step:2400/7000 train_loss:2.1632 train_time:4720234ms step_avg:1966.76ms +step:2500/7000 val_loss:2.0874 val_bpb:1.2363 train_time:4917159ms step_avg:1966.86ms +step:2600/7000 train_loss:2.1639 train_time:5114189ms step_avg:1967.00ms +step:2800/7000 train_loss:2.2085 train_time:5508072ms step_avg:1967.17ms +step:3000/7000 train_loss:2.1608 train_time:5901941ms step_avg:1967.31ms +step:3000/7000 val_loss:2.0742 val_bpb:1.2285 train_time:5901944ms step_avg:1967.31ms +step:3200/7000 train_loss:2.1644 train_time:6295959ms step_avg:1967.49ms +step:3400/7000 train_loss:2.1390 train_time:6690073ms step_avg:1967.67ms +step:3500/7000 val_loss:2.0663 val_bpb:1.2238 train_time:6887089ms step_avg:1967.74ms +step:3600/7000 train_loss:2.1812 train_time:7084186ms step_avg:1967.83ms +step:3800/7000 train_loss:2.1519 train_time:7478201ms step_avg:1967.95ms +step:4000/7000 train_loss:2.2331 train_time:7872398ms step_avg:1968.10ms +step:4000/7000 val_loss:2.0608 val_bpb:1.2205 train_time:7872401ms step_avg:1968.10ms +step:4200/7000 train_loss:2.1234 train_time:8266454ms step_avg:1968.20ms +step:4400/7000 train_loss:2.1051 train_time:8660328ms step_avg:1968.26ms +step:4500/7000 val_loss:2.0439 val_bpb:1.2105 train_time:8857296ms step_avg:1968.29ms +step:4600/7000 train_loss:2.0881 train_time:9054351ms step_avg:1968.34ms +step:4800/7000 train_loss:2.2217 train_time:9448372ms step_avg:1968.41ms +step:5000/7000 train_loss:2.1346 train_time:9842356ms step_avg:1968.47ms +step:5000/7000 val_loss:2.0234 val_bpb:1.1983 train_time:9842359ms step_avg:1968.47ms +step:5200/7000 train_loss:2.1035 train_time:10236364ms step_avg:1968.53ms +step:5400/7000 train_loss:2.0777 train_time:10630253ms step_avg:1968.57ms +step:5500/7000 val_loss:2.0005 val_bpb:1.1848 train_time:10827248ms step_avg:1968.59ms +step:5600/7000 train_loss:2.0621 train_time:11024162ms step_avg:1968.60ms +step:5800/7000 train_loss:2.0492 train_time:11418063ms step_avg:1968.63ms +step:6000/7000 train_loss:2.0302 train_time:11812016ms step_avg:1968.67ms +step:6000/7000 val_loss:1.9772 val_bpb:1.1710 train_time:11812019ms step_avg:1968.67ms +step:6200/7000 train_loss:2.1248 train_time:12205822ms step_avg:1968.68ms +step:6400/7000 train_loss:2.1042 train_time:12599771ms step_avg:1968.71ms +step:6500/7000 val_loss:1.9471 val_bpb:1.1532 train_time:12796711ms step_avg:1968.72ms +step:6600/7000 train_loss:2.0032 train_time:12993629ms step_avg:1968.73ms +step:6800/7000 train_loss:2.0856 train_time:13387398ms step_avg:1968.74ms +step:7000/7000 train_loss:1.9287 train_time:13780882ms step_avg:1968.70ms +step:7000/7000 val_loss:1.9223 val_bpb:1.1385 train_time:13780886ms step_avg:1968.70ms 
+peak memory allocated: 25832 MiB reserved: 26006 MiB +ema:applying EMA weights +save_paths: pt=final_model_pr940_nflow_s1337_55398556.pt ptz=final_model_pr940_nflow_s1337_55398556.int6.ptz +Serialized model: 107294152 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16265007 bytes (limit 15884968) +quant_try int6 zstd-1: 16289666 bytes (limit 15884968) +quant_try int6 zstd-17: 16250866 bytes (limit 15884968) +quant_try int6 zstd-2: 16298408 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15944126 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 15995452 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 16042936 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16051685 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15621901 bytes (limit 15884968) +Serialized model quant+zstd-16: 15621901 bytes +Total submission size: 15736933 bytes +final_int6_roundtrip val_loss:1.9372 val_bpb:1.1473 eval_time:67819ms +final_int6_roundtrip_exact val_loss:1.93715323 val_bpb:1.14729126 +final_int6_sliding_window val_loss:1.8973 val_bpb:1.1237 stride:64 eval_time:1709300ms +final_int6_sliding_window_exact val_loss:1.89726464 val_bpb:1.12366996 diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/train_s2025.log b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/train_s2025.log new file mode 100644 index 0000000000..7c28d28a4e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/seed_runs/train_s2025.log @@ -0,0 +1,2706 @@ +logs/pr940_nflow_s2025_55398557.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:2025 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/7000 val_loss:6.9336 val_bpb:4.1065 train_time:0ms step_avg:0.01ms +step:1/7000 train_loss:7.1343 train_time:1860ms step_avg:1860.48ms +step:2/7000 train_loss:9.1272 train_time:3803ms step_avg:1901.28ms +step:3/7000 train_loss:8.0549 train_time:5744ms step_avg:1914.78ms +step:4/7000 train_loss:7.4630 train_time:7688ms step_avg:1921.99ms +step:5/7000 train_loss:7.2148 train_time:9633ms step_avg:1926.66ms +step:6/7000 train_loss:7.1310 train_time:11577ms step_avg:1929.52ms +step:7/7000 train_loss:6.9741 train_time:13525ms step_avg:1932.07ms +step:8/7000 train_loss:6.7660 train_time:15472ms step_avg:1933.95ms +step:9/7000 train_loss:6.4398 train_time:17421ms step_avg:1935.71ms 
+step:10/7000 train_loss:6.1453 train_time:19372ms step_avg:1937.16ms +step:200/7000 train_loss:2.7566 train_time:392618ms step_avg:1963.09ms +step:400/7000 train_loss:2.5625 train_time:784888ms step_avg:1962.22ms +step:500/7000 val_loss:2.3400 val_bpb:1.3859 train_time:980822ms step_avg:1961.64ms +step:600/7000 train_loss:2.3837 train_time:1176807ms step_avg:1961.35ms +step:800/7000 train_loss:2.3724 train_time:1568577ms step_avg:1960.72ms +step:1000/7000 train_loss:2.3572 train_time:1960683ms step_avg:1960.68ms +step:1000/7000 val_loss:2.2068 val_bpb:1.3070 train_time:1960686ms step_avg:1960.69ms +step:1200/7000 train_loss:2.3121 train_time:2352803ms step_avg:1960.67ms +step:1400/7000 train_loss:2.3508 train_time:2745049ms step_avg:1960.75ms +step:1500/7000 val_loss:2.1638 val_bpb:1.2815 train_time:2941195ms step_avg:1960.80ms +step:1600/7000 train_loss:2.1936 train_time:3137383ms step_avg:1960.86ms +step:1800/7000 train_loss:2.2421 train_time:3530037ms step_avg:1961.13ms +step:2000/7000 train_loss:2.1327 train_time:3922761ms step_avg:1961.38ms +step:2000/7000 val_loss:2.1089 val_bpb:1.2490 train_time:3922764ms step_avg:1961.38ms +step:2200/7000 train_loss:2.2111 train_time:4315287ms step_avg:1961.49ms +step:2400/7000 train_loss:2.1628 train_time:4708019ms step_avg:1961.67ms +step:2500/7000 val_loss:2.0850 val_bpb:1.2349 train_time:4904412ms step_avg:1961.76ms +step:2600/7000 train_loss:2.1636 train_time:5100872ms step_avg:1961.87ms +step:2800/7000 train_loss:2.2077 train_time:5493564ms step_avg:1961.99ms +step:3000/7000 train_loss:2.1581 train_time:5886202ms step_avg:1962.07ms +step:3000/7000 val_loss:2.0716 val_bpb:1.2269 train_time:5886206ms step_avg:1962.07ms +step:3200/7000 train_loss:2.1613 train_time:6278939ms step_avg:1962.17ms +step:3400/7000 train_loss:2.1352 train_time:6671814ms step_avg:1962.30ms +step:3500/7000 val_loss:2.0634 val_bpb:1.2221 train_time:6868256ms step_avg:1962.36ms +step:3600/7000 train_loss:2.1780 train_time:7064671ms step_avg:1962.41ms +step:3800/7000 train_loss:2.1471 train_time:7457455ms step_avg:1962.49ms +step:4000/7000 train_loss:2.2282 train_time:7850357ms step_avg:1962.59ms +step:4000/7000 val_loss:2.0582 val_bpb:1.2190 train_time:7850360ms step_avg:1962.59ms +step:4200/7000 train_loss:2.1200 train_time:8243179ms step_avg:1962.66ms +step:4400/7000 train_loss:2.1026 train_time:8635741ms step_avg:1962.67ms +step:4500/7000 val_loss:2.0417 val_bpb:1.2092 train_time:8832080ms step_avg:1962.68ms +step:4600/7000 train_loss:2.0850 train_time:9028513ms step_avg:1962.72ms +step:4800/7000 train_loss:2.2164 train_time:9421306ms step_avg:1962.77ms +step:5000/7000 train_loss:2.1292 train_time:9814044ms step_avg:1962.81ms +step:5000/7000 val_loss:2.0190 val_bpb:1.1958 train_time:9814048ms step_avg:1962.81ms +step:5200/7000 train_loss:2.0990 train_time:10206705ms step_avg:1962.83ms +step:5400/7000 train_loss:2.0739 train_time:10599477ms step_avg:1962.87ms +step:5500/7000 val_loss:1.9965 val_bpb:1.1825 train_time:10795770ms step_avg:1962.87ms +step:5600/7000 train_loss:2.0598 train_time:10992139ms step_avg:1962.88ms +step:5800/7000 train_loss:2.0450 train_time:11384873ms step_avg:1962.91ms +step:6000/7000 train_loss:2.0260 train_time:11777706ms step_avg:1962.95ms +step:6000/7000 val_loss:1.9731 val_bpb:1.1686 train_time:11777710ms step_avg:1962.95ms +step:6200/7000 train_loss:2.1195 train_time:12170526ms step_avg:1962.99ms +step:6400/7000 train_loss:2.0991 train_time:12563320ms step_avg:1963.02ms +step:6500/7000 val_loss:1.9430 val_bpb:1.1508 train_time:12759725ms 
step_avg:1963.03ms +step:6600/7000 train_loss:1.9991 train_time:12956061ms step_avg:1963.04ms +step:6800/7000 train_loss:2.0804 train_time:13348702ms step_avg:1963.04ms +step:7000/7000 train_loss:1.9237 train_time:13741101ms step_avg:1963.01ms +step:7000/7000 val_loss:1.9180 val_bpb:1.1359 train_time:13741104ms step_avg:1963.01ms +peak memory allocated: 25832 MiB reserved: 26006 MiB +ema:applying EMA weights +save_paths: pt=final_model_pr940_nflow_s2025_55398557.pt ptz=final_model_pr940_nflow_s2025_55398557.int6.ptz +Serialized model: 107294152 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16235273 bytes (limit 15884968) +quant_try int6 zstd-1: 16306740 bytes (limit 15884968) +quant_try int6 zstd-17: 16260153 bytes (limit 15884968) +quant_try int6 zstd-2: 16312469 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15932289 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16009919 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 15954917 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16059927 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15630918 bytes (limit 15884968) +Serialized model quant+zstd-16: 15630918 bytes +Total submission size: 15745950 bytes +final_int6_roundtrip val_loss:1.9323 val_bpb:1.1444 eval_time:67830ms +final_int6_roundtrip_exact val_loss:1.93234887 val_bpb:1.14444585 +final_int6_sliding_window val_loss:1.8924 val_bpb:1.1208 stride:64 eval_time:1699977ms +final_int6_sliding_window_exact val_loss:1.89237638 val_bpb:1.12077485 +TIVE_FLOW_LOSS_WEIGHT", "0.1")) + + # E2E TTT-Linear refiner (Sun et al., 2024) + e2e_ttt_enabled = bool(int(os.environ.get("E2E_TTT_ENABLED", "0"))) + e2e_ttt_num_heads = int(os.environ.get("E2E_TTT_NUM_HEADS", "8")) + e2e_ttt_mini_batch = int(os.environ.get("E2E_TTT_MINI_BATCH", "16")) + e2e_ttt_base_lr = float(os.environ.get("E2E_TTT_BASE_LR", "1.0")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = 
p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
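+    # One extra trailing token is kept (tokens[: usable + 1]) so the shifted
+    # (input, target) pairs built in the eval loops line up, mirroring DistributedTokenLoader.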
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
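+# A rough worked example of the per-row scheme below (illustrative numbers, not taken
+# from a real checkpoint): if a weight row's 99.99984th-percentile |w| is 0.06, the
+# int8 scale is 0.06 / 127 ≈ 4.7e-4 and each weight is stored as
+# q = clamp(round(w / scale), -127, 127); dequantizing as q * scale is then off by at
+# most about scale / 2 ≈ 2.4e-4 for in-range weights. The int6 (qmax=31) and
+# int5 (qmax=15) MLP fallbacks trade roughly 4x / 8x coarser scales for a smaller payload.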
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,vr_lambda,attn_gate,canon_a,canon_c,delta_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, qmax: int = 127) -> tuple[Tensor, Tensor]: + """Quantize to [-qmax, qmax] range. Default int8 (qmax=127), int6 (qmax=31), int5 (qmax=15).""" + t32 = t.float() + qmin = -qmax + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
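+        # In this 512-dim configuration that means per-channel scales, gates, and norm
+        # vectors all pass through, while the attention/MLP/embedding matrices
+        # (131,072+ elements each) take the quantized path below.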
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 for MLP layers 3-7 to save artifact space + int6_mlp_layers = os.environ.get("INT6_MLP_LAYERS", "") + qmax = 127 # default int8 + if int6_mlp_layers: + for li in int6_mlp_layers.split(","): + if li.strip() and f"blocks.{li.strip()}.mlp" in name and t.ndim == 2: + qmax = 31 # int6 + break + q, s = quantize_float_tensor(t, qmax=qmax) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
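+    # Worked example with the run settings from the logs above (world_size=1,
+    # grad_accum_steps=8, train_batch_tokens=786432, train_seq_len=2048): each call
+    # consumes 786432 // 8 + 1 = 98305 tokens, and the first 98304 are reshaped into
+    # 48 sequences of length 2048 for x, with y the same block shifted by one token.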
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round: bool = False + _soft_round_alpha: float = 1.0 + _quant_percentile: float = float(os.environ.get("QUANT_PERCENTILE", "1.0")) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + w32 = self.weight.float() + pct = CastedLinear._quant_percentile + row_max = (torch.quantile(w32.abs(), pct, dim=1) if pct < 1.0 + else w32.abs().amax(dim=1)).detach() + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + r = w32 / scale[:, None] + if CastedLinear._soft_round: + alpha = CastedLinear._soft_round_alpha + r_frac = r - r.detach().floor() - 0.5 + norm = torch.tanh(torch.tensor(alpha * 0.5, device=r.device, dtype=r.dtype)) + r_soft = r.detach().floor() + 0.5 + torch.tanh(alpha * r_frac) / (2.0 * norm) + w_q = (torch.clamp(r_soft, -32, 31) * scale[:, None]).to(x.dtype) + w = w_q # soft-round is differentiable, no STE needed + else: + with torch.no_grad(): + w_q = (torch.clamp(torch.round(r), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. 
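+    # Concretely, forward() rescales as new_base = base * (seq_len / train_seq_len) ** (rd / (rd - 2)),
+    # so with rope_dims=16, base=10000, and a 2048-token sequence over train_seq_len=1024
+    # the effective base grows to roughly 10000 * 2 ** (16 / 14) ≈ 22,000 (example numbers).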
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + value_residual: bool = False, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = 
v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + fa_dtype = torch.bfloat16 + if _USE_FA3: + y = flash_attn_3_func(q.to(fa_dtype), k.to(fa_dtype), v.to(fa_dtype), causal=True) + else: + # SDPA fallback: (B, T, H, D) -> (B, H, T, D), expand KV for GQA + q_t = q.to(fa_dtype).transpose(1, 2) + k_t = k.to(fa_dtype).transpose(1, 2) + v_t = v.to(fa_dtype).transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(rep, dim=1) + v_t = v_t.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True) + y = y.transpose(1, 2) # (B, H, T, D) -> (B, T, H, D) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)) # (B, T, num_heads) + y = y * gate.unsqueeze(-1) # (B, T, H, 1) broadcast to (B, T, H, D) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.use_leaky = bool(int(os.environ.get("LEAKY_RELU", "1"))) + self.leaky_slope = 
float(os.environ.get("LEAKY_SLOPE", "0.5")) + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), self.leaky_slope) if self.use_leaky else torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class CanonAC(nn.Module): + """Canon Autoregressive Convolution with DeltaGate. Manual shift+mul (no Conv1d).""" + def __init__(self, dim: int, kernel: int = 4, delta_gate_init: float = -4.0): + super().__init__() + self.kernel = kernel + self.weight = nn.Parameter(torch.zeros(kernel, dim)) + self.delta_gate_logit = nn.Parameter(torch.tensor(delta_gate_init)) + + def forward(self, x: Tensor) -> Tensor: + B, T, D = x.shape + K = self.kernel + w = self.weight.to(x.dtype) + x_pad = F.pad(x, (0, 0, K - 1, 0)) + y = w[0] * x_pad[:, K - 1:, :] + for k in range(1, K): + y = y + w[k] * x_pad[:, K - 1 - k : T + K - 1 - k, :] + gate = torch.sigmoid(self.delta_gate_logit.to(x.dtype)) + return x + gate * y + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_kernel: int = 0, + canon_delta_gate_init: float = -4.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, value_residual=value_residual, + gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.canon_a = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + self.canon_c = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_in = self.attn_norm(x) * s + if self.canon_a is not None: + attn_in = self.canon_a(attn_in) + attn_out, raw_v = self.attn(attn_in, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s + if self.canon_c is not None: + mlp_in = self.canon_c(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + return x, raw_v + + +class FlowRefiner(nn.Module): + """1-step flow matching refiner in low-dim latent space. + + Projects hidden states to a low-dim latent, applies a learned velocity + field (1-step Euler), projects back. Output is additive adjustment to + hidden states. Initialized near-zero so the refiner starts as identity. 
+ """ + def __init__(self, model_dim: int, latent_dim: int = 64, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.down_proj = nn.Linear(model_dim, latent_dim, bias=False) + self.velocity_net = nn.Sequential( + nn.Linear(latent_dim, hidden_dim, bias=True), + nn.GELU(), + nn.Linear(hidden_dim, latent_dim, bias=True), + ) + self.up_proj = nn.Linear(latent_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.velocity_net[2]._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.down_proj.weight, std=self._init_scale) + nn.init.normal_(self.velocity_net[0].weight, std=self._init_scale) + nn.init.zeros_(self.velocity_net[0].bias) + nn.init.zeros_(self.velocity_net[2].weight) + nn.init.zeros_(self.velocity_net[2].bias) + nn.init.normal_(self.up_proj.weight, std=self._init_scale) + + def forward(self, x: Tensor) -> Tensor: + z = self.down_proj(x) + v = self.velocity_net(z) + z_refined = z + v + delta = self.up_proj(z_refined) + return torch.sigmoid(self.gate) * delta + + +class NativeFlowMatcher(nn.Module): + """Conditional Flow Matching refiner for hidden states. + + During training: computes auxiliary CFM loss on interpolated states. + During inference: applies a single Euler step correction at t=1. + """ + def __init__(self, model_dim: int, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.model_dim = model_dim + # Sinusoidal time embedding → projected to hidden_dim + self.time_proj = nn.Sequential( + nn.Linear(model_dim, hidden_dim, bias=True), + nn.GELU(), + ) + # Velocity network: x → hidden → x (with time conditioning via addition) + self.v_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.v_act = nn.GELU() + self.v_out = nn.Linear(hidden_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.v_out._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.v_in.weight, std=self._init_scale) + nn.init.zeros_(self.v_in.bias) + nn.init.normal_(self.time_proj[0].weight, std=self._init_scale) + nn.init.zeros_(self.time_proj[0].bias) + nn.init.zeros_(self.v_out.weight) + + @staticmethod + def _sinusoidal_time_emb(t: Tensor, dim: int) -> Tensor: + """Sinusoidal positional embedding for scalar time t. t: (...,) -> (..., dim).""" + half_dim = dim // 2 + emb = math.log(10000.0) / max(half_dim - 1, 1) + emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb) + emb = t.unsqueeze(-1) * emb # (..., half_dim) + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) # (..., dim) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + def _velocity(self, x: Tensor, t_emb: Tensor) -> Tensor: + """Compute velocity v(x, t). x: (..., model_dim), t_emb: (..., hidden_dim).""" + h = self.v_in(x) # (..., hidden_dim) + h = h + t_emb # time conditioning via addition + h = self.v_act(h) # GELU + return self.v_out(h) # (..., model_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (correction, cfm_loss). + + During training: cfm_loss is the flow matching objective (> 0). + During eval: cfm_loss is zero. + correction: gated velocity at t=1, to be added to x. 
+ """ + B_seq = x.shape[:-1] # works for any leading dimensions + D = self.model_dim + + # Inference path: velocity at t=1 (clean input, no noise) + t_ones = torch.ones(*B_seq, device=x.device, dtype=x.dtype) + t_emb_ones = self.time_proj(self._sinusoidal_time_emb(t_ones, D)) + v_at_1 = self._velocity(x, t_emb_ones) + correction = torch.sigmoid(self.gate) * v_at_1 + + # Training path: auxiliary CFM loss + if self.training: + # Sample random t ~ U(0,1) per position + t = torch.rand(*B_seq, device=x.device, dtype=x.dtype) + z = torch.randn_like(x) # noise + # OT interpolant: x_t = (1-t)*z + t*x + t_expanded = t.unsqueeze(-1) # (..., 1) + x_t = (1.0 - t_expanded) * z + t_expanded * x.detach() + # Target velocity = x - z (OT conditional velocity field) + target_v = x.detach() - z + # Predict velocity at interpolated point + t_emb = self.time_proj(self._sinusoidal_time_emb(t, D)) + pred_v = self._velocity(x_t, t_emb) + cfm_loss = F.mse_loss(pred_v, target_v) + else: + cfm_loss = x.new_zeros(()) + + return correction, cfm_loss + + +class TTTLinearRefiner(nn.Module): + """E2E TTT-Linear refiner (Sun et al., 2024, arXiv:2407.04620). + + Applied to final hidden states [B, L, D] before lm_head. The hidden + state is a per-head linear model W updated via mini-batch gradient + descent on a learned self-supervised reconstruction task. Uses the + dual form for GPU efficiency. Trained end-to-end: the outer loop + optimizes theta_K, theta_V, theta_Q projections and the inner-loop + initialization W_0, making the model learn WHAT to compress. + """ + + def __init__(self, model_dim: int, num_heads: int = 8, + mini_batch_size: int = 16, ttt_base_lr: float = 1.0): + super().__init__() + self.model_dim = model_dim + self.num_heads = num_heads + self.head_dim = model_dim // num_heads + self.mini_batch_size = mini_batch_size + self.ttt_base_lr = ttt_base_lr + self.eta = ttt_base_lr / self.head_dim + + # Learnable projections (outer-loop) + self.theta_K = nn.Linear(model_dim, model_dim, bias=False) + self.theta_V = nn.Linear(model_dim, model_dim, bias=False) + self.theta_Q = nn.Linear(model_dim, model_dim, bias=False) + + # Inner-loop model: W_0, b_0 + self.W1 = nn.Parameter(torch.normal(0, 0.02, + size=(num_heads, self.head_dim, self.head_dim))) + self.b1 = nn.Parameter(torch.zeros(num_heads, 1, self.head_dim)) + + # Inner-loop LayerNorm (outer-loop trainable) + self.ttt_ln_w = nn.Parameter(torch.ones(num_heads, self.head_dim)) + self.ttt_ln_b = nn.Parameter(torch.zeros(num_heads, self.head_dim)) + + # Output projection + gate (starts near zero for residual safety) + self.o_proj = nn.Linear(model_dim, model_dim, bias=False) + self.post_norm = nn.LayerNorm(model_dim, eps=1e-6) + self.gate = nn.Parameter(torch.tensor(-5.0)) + nn.init.zeros_(self.o_proj.weight) + self.o_proj._zero_init = True + + @staticmethod + def _ln_fwd(x: Tensor, gamma: Tensor, beta: Tensor, eps: float = 1e-6) -> Tensor: + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x_hat = (x - mu) / torch.sqrt(var + eps) + return gamma * x_hat + beta + + @staticmethod + def _ln_fused_l2_bwd(x: Tensor, target: Tensor, gamma: Tensor, beta: Tensor, + eps: float = 1e-6) -> Tensor: + D = x.shape[-1] + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + std = torch.sqrt(var + eps) + x_hat = (x - mu) / std + y = gamma * x_hat + beta + grad_output = y - target + grad_x_hat = grad_output * gamma + z = (1.0 / D) * ( + D * grad_x_hat + - grad_x_hat.sum(dim=-1, keepdim=True) + - x_hat * 
(grad_x_hat * x_hat).sum(dim=-1, keepdim=True) + ) / std + return z + + def forward(self, x: Tensor) -> Tensor: + """x: [B, L, D]. Returns [B, L, D] additive refinement (gated).""" + B, L, D = x.shape + NH, HD, b = self.num_heads, self.head_dim, self.mini_batch_size + eta = self.eta + + num_mb = L // b + usable_L = num_mb * b + x_proc = x[:, :usable_L] if usable_L < L else x + + # Project: [B, L, D] -> [B, NH, L, HD] + XK = self.theta_K(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XV = self.theta_V(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XQ = self.theta_Q(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + + # Reshape into mini-batches: [B, NH, num_mb, b, HD] + XK = XK.reshape(B, NH, num_mb, b, HD) + XV = XV.reshape(B, NH, num_mb, b, HD) + XQ = XQ.reshape(B, NH, num_mb, b, HD) + + ln_w = self.ttt_ln_w.reshape(1, NH, 1, HD) + ln_b = self.ttt_ln_b.reshape(1, NH, 1, HD) + + # Causal mask for dual form + causal = torch.tril(torch.ones(b, b, device=x.device, dtype=x.dtype)) + + # Initialize inner-loop model + W = self.W1.unsqueeze(0).expand(B, -1, -1, -1).clone() + bias = self.b1.unsqueeze(0).expand(B, -1, -1, -1).clone() + + all_out = [] + for m in range(num_mb): + xk = XK[:, :, m] + xv = XV[:, :, m] + xq = XQ[:, :, m] + + Z1 = xk @ W + bias + target = xv - xk + + grad = self._ln_fused_l2_bwd(Z1, target, ln_w, ln_b) + + Attn = causal.unsqueeze(0).unsqueeze(0) * (xq @ xk.transpose(-2, -1)) + cumgrad = (causal.unsqueeze(0).unsqueeze(0) @ grad) + b_bar = bias - eta * cumgrad + Z_bar = xq @ W - eta * Attn @ grad + b_bar + + W = W - eta * xk.transpose(-2, -1) @ grad + bias = bias - eta * grad.sum(dim=-2, keepdim=True) + + Z_bar = self._ln_fwd(Z_bar, ln_w, ln_b) + out_mb = xq + Z_bar + all_out.append(out_mb) + + z = torch.cat(all_out, dim=2).permute(0, 2, 1, 3).reshape(B, usable_L, D) + z = self.post_norm(z) + z = self.o_proj(z) + result = torch.sigmoid(self.gate) * z + + if usable_L < L: + pad = torch.zeros(B, L - usable_L, D, device=x.device, dtype=x.dtype) + result = torch.cat([result, pad], dim=1) + + return result + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_last_n: int = 0, + canon_kernel: int = 4, + canon_delta_gate_init: float = -4.0, + flow_enabled: bool = False, + flow_latent_dim: int = 64, + flow_hidden_dim: int = 256, + flow_init_scale: float = 0.01, + native_flow_enabled: bool = False, + native_flow_hidden_dim: int = 256, + native_flow_init_scale: float = 0.01, + native_flow_loss_weight: float = 0.1, + e2e_ttt_enabled: bool = False, + e2e_ttt_num_heads: int = 8, + e2e_ttt_mini_batch: int = 16, + e2e_ttt_base_lr: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, 
bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + canon_start = num_layers - canon_last_n if canon_last_n > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + value_residual=value_residual, + gated_attention=gated_attention, + canon_kernel=canon_kernel if i >= canon_start else 0, + canon_delta_gate_init=canon_delta_gate_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.flow_refiner = FlowRefiner(model_dim, flow_latent_dim, flow_hidden_dim, flow_init_scale) if flow_enabled else None + self.native_flow = NativeFlowMatcher(model_dim, native_flow_hidden_dim, native_flow_init_scale) if native_flow_enabled else None + self.native_flow_loss_weight = native_flow_loss_weight + self.ttt_refiner = TTTLinearRefiner(model_dim, e2e_ttt_num_heads, e2e_ttt_mini_batch, e2e_ttt_base_lr) if e2e_ttt_enabled else None + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + nflow_loss = x.new_zeros(()) + if self.native_flow is not None: + correction, nflow_loss = self.native_flow(x) + x = x + correction + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + if self.native_flow is not None and self.training: + main_loss = main_loss + self.native_flow_loss_weight * nflow_loss + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + if self.native_flow is not None: + correction, _ = self.native_flow(x) + x = x + correction + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + 
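+# Illustrative sketch (not part of the pipeline): which positions of each
+# sliding window contribute to the BPB sum in eval_val_sliding below. Every
+# target is predicted with (near-)maximal left context, and overlapping
+# windows do not re-score targets an earlier window already covered (apart
+# from possible edge effects at the very tail of the validation stream).
+# The helper name _scored_slice is hypothetical and exists only for this
+# comment block.
+def _scored_slice(window_start: int, window_len: int, stride: int) -> tuple[int, int]:
+    """Return the half-open [s, e) slice of in-window positions that get scored."""
+    s = 0 if window_start == 0 else max(window_len - stride, 0)
+    return s, window_len
+
+# Example (seq_len=1024, stride=64): the window starting at 0 scores all of
+# its targets, while every later full-length window scores only its last 64.
+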
+def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + Optionally uses entropy-gated 5-gram cache (NGRAM_CACHE=1).""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # N-gram eval cache with multi-order backoff + entropy-adaptive alpha (PR #702 inspired) + _ngram_default = "1" if world_size > 1 else "0" + use_ngram = bool(int(os.environ.get("NGRAM_CACHE", _ngram_default))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.40")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_order = int(os.environ.get("NGRAM_ORDER", "7")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = ngram_order - ngram_min_order + 1 + ctx_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + print(f"ngram_cache:enabled orders={ngram_min_order}-{ngram_order} backoff " + f"entropy={ngram_entropy} alpha={ngram_alpha} " + f"ent_base={ngram_ent_base} ent_range={ngram_ent_range} " + f"min_count={ngram_min_count} buckets={ngram_buckets}", flush=True) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + scored_nll = nll[i, s:wlen].to(torch.float64) + + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha: compute from model logits (GPU) + if ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ngram_ent_base + ngram_ent_range / ( + 1.0 + np.exp(-ngram_ent_scale * (seg_ent - ngram_ent_thresh))) + + # Precompute hashes for all orders + order_data = [] # (v_idx, ctx_key, full_key) per order + for oi in range(_n_orders): + ctx_w = ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, fill unmatched with lower orders + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if ngram_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll_np = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + scored_nll = torch.from_numpy(seg_nll_np).to(dtype=torch.float64, device=device) + + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + 
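+
+# Illustrative sketch of the n-gram blending performed inside eval_val_sliding
+# when NGRAM_CACHE=1. The cache is score-first: a token's counts are added to
+# the hash tables only AFTER that token has been scored, so no future
+# information leaks into its own probability. The function below is a scalar
+# rendering of the vectorised numpy code above; _mix_with_ngram is a
+# hypothetical name used only for this example.
+def _mix_with_ngram(model_p: float, ngram_p: float | None, alpha: float) -> float:
+    """Blend the model probability with an n-gram estimate when one matched."""
+    if ngram_p is None:              # no context with count >= NGRAM_MIN_COUNT
+        return model_p
+    return (1.0 - alpha) * model_p + alpha * ngram_p
+
+# With NGRAM_ENTROPY=1 the blend weight is computed per token from the model's
+# predictive entropy H:
+#   alpha(H) = NGRAM_ENT_BASE + NGRAM_ENT_RANGE * sigmoid(NGRAM_ENT_SCALE * (H - NGRAM_ENT_THRESH))
+# so confident (low-entropy) predictions are left nearly untouched while
+# uncertain ones lean harder on the cached counts.
+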
+ +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _clear_rotary_caches(model: nn.Module) -> None: + """Clear cached RoPE tensors to avoid 'Inference tensors cannot be saved for backward'.""" + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + + +def ttt_adapt(args: Hyperparameters, base_model: nn.Module, device: torch.device, + val_tokens: Tensor, rank: int = 0, world_size: int = 1, + log_fn=None) -> None: + """Score-first TTT: process val data in chunks, score each chunk first + (inference_mode), then train on scored tokens. Compliant with Issue #677.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + batch_seqs = args.ttt_batch_seqs + + # Freeze early blocks + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + chunk_idx = 0 + + for chunk_start in range(0, total_tokens - seq_len, chunk_tokens): + chunk_end = min(chunk_start + chunk_tokens, total_tokens) + chunk_len = chunk_end - chunk_start + n_seqs = chunk_len // seq_len + if n_seqs == 0: + break + + my_start = (n_seqs * rank) // world_size + my_end = (n_seqs * (rank + 1)) // world_size + if my_end <= my_start: + continue + + # Phase 1: Score chunk under inference_mode (forward only) + base_model.eval() + with torch.inference_mode(): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + base_model.forward_logits(x) + + # Phase 2: Train on scored tokens (K epochs) + base_model.train() + for epoch in range(args.ttt_epochs): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + chunk_idx += 1 + if log_fn and chunk_idx % 20 == 0: + log_fn(f"ttt:chunk={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + # Restore all params + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done chunks={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, 
+ eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + """Legal single-pass TTT: score each chunk with sliding windows, then train on it. + Tokens are always scored BEFORE any training on their chunk, so the evaluation + is never contaminated by future information.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Build window starts (same logic as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Map each window to the chunk that contains its first scored token + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + if log_fn: + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + if log_fn: + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + _clear_rotary_caches(base_model) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR schedule with 5% warmup + warmup_chunks = max(num_chunks // 20, 1) + if ci < warmup_chunks: + lr_scale = (ci + 1) / warmup_chunks + else: + progress = (ci - warmup_chunks) / max(num_chunks - 1 - warmup_chunks, 1) + lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress)) + cos_lr = args.ttt_lr * lr_scale + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if log_fn and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all params and return to eval mode + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + if log_fn: + log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, qmax: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + qmin = -qmax - 1 + pct = CastedLinear._quant_percentile + if t32.ndim == 2: + row_max = (torch.quantile(t32.abs(), pct, dim=1) if pct < 1.0 + else t32.abs().amax(dim=1)) + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)).to(torch.float16) + clipped = t32.clamp(-row_max[:, None], row_max[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), qmin, qmax).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / float(qmax) if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), qmin, qmax).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_layers: set[int] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + if int5_layers is None: + int5_layers = set() + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Determine layer index for int5 fallback + layer_idx = -1 + if name.startswith("blocks."): + try: + layer_idx = int(name.split(".")[1]) + except (IndexError, ValueError): + pass + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + qmax = 15 if layer_idx in int5_layers else 31 + q, s = 
quantize_int6_per_row(t, qmax=qmax) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5" if qmax == 15 else "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + if _USE_FA3: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.flow_refiner is not None: + scalar_params.extend(list(base_model.flow_refiner.parameters())) + if base_model.native_flow is not None: + scalar_params.extend(list(base_model.native_flow.parameters())) + if base_model.ttt_refiner is not None: + scalar_params.extend(list(base_model.ttt_refiner.parameters())) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:fa3={_USE_FA3} cudnn=False flash=True mem_efficient={not _USE_FA3} math={not _USE_FA3}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 
1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + eval_only_path = os.environ.get("EVAL_ONLY", "") + if eval_only_path: + log0(f"eval_only: loading {eval_only_path}, skipping training") + base_model.load_state_dict(torch.load(eval_only_path, map_location=device, weights_only=False), strict=False) + ema_state = None # prevent random EMA from overwriting loaded weights + swa_state = None + swa_count = 0 + args.iterations = 0 # skip training, go straight to eval + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms 
step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if ( + args.checkpoint_every > 0 + and step > 0 + and step % args.checkpoint_every == 0 + and not last_step + and master_process + ): + ckpt_sd = {k: v for k, v in base_model.state_dict().items() if "mtp_heads" not in k} + ckpt_path = f"checkpoint_step{step}_{args.run_id}.pt" + torch.save(ckpt_sd, ckpt_path) + log0(f"checkpoint_saved: {ckpt_path} ({os.path.getsize(ckpt_path)} bytes)") + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._soft_round = args.soft_round_qat + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round:{args.soft_round_qat}") + if CastedLinear._qat_enabled and CastedLinear._soft_round: + qat_progress = max(0.0, 1.0 - (scale / qat_threshold)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress # 1→16 + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: + # Tight SWA: collect from EMA state if available, else from raw model + src = ema_state if ema_state is not None else {name: t.detach().float() for name, t in base_model.state_dict().items()} + if swa_state is None: + swa_state = {name: t.clone() for name, t in src.items()} + swa_count = 1 + log0(f"swa:start step:{step} tight={ema_state is not None}") + else: + for name in swa_state: + swa_state[name].add_(src[name]) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the 
wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying Tight SWA averaged {swa_count} EMA checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + if ema_state is not None: + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + # Use RUN_ID for unique filenames when multiple jobs share CWD + _run_id = os.environ.get("RUN_ID", "") + _save_prefix = f"final_model_{_run_id}" if _run_id else "final_model" + _pt_path = f"{_save_prefix}.pt" + _ptz_path = f"{_save_prefix}.int6.ptz" + log0(f"save_paths: pt={_pt_path} ptz={_ptz_path}") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, _pt_path) + model_bytes = os.path.getsize(_pt_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + code_bytes = len(code.encode("utf-8")) + artifact_limit = 16_000_000 - code_bytes + + # --- Auto-downgrade quantization: try int6 first, fall back to int5 middle layers --- + num_layers_total = max( + (int(k.split(".")[1]) for k in sd_cpu if k.startswith("blocks.")), + default=0, + ) + 1 + _zstd_levels = [int(os.environ.get("ZSTD_LEVEL", "16")), 1, 17, 2] + # Phase 1: pure int6 with multiple zstd levels + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = None + chosen_level = _zstd_levels[0] + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int6 zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + # Phase 2: progressive int5 fallback — one layer at a time from middle outward + if quant_blob is None: + mid = num_layers_total // 2 + # Expand outward from center: L5, L4, L6, L3, L7, L2, L8, ... 
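+        # (mid + offset is tried before mid - offset at each step, so for the
+        #  11-layer model the search order is 5, 6, 4, 7, 3, 8, 2, 9, 1, 10, 0.)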
+ candidates = [] + for offset in range(num_layers_total): + for sign in [0, 1]: + layer = mid + offset if sign == 0 else mid - offset + if 0 <= layer < num_layers_total and layer not in candidates: + candidates.append(layer) + int5_layers: set[int] = set() + for layer in candidates: + int5_layers.add(layer) + if master_process: + log0(f"quant_fallback: int5 layers={sorted(int5_layers)}") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, int5_layers=int5_layers) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int5[{len(int5_layers)}L] zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + if quant_blob is not None: + break + if quant_blob is None: + quant_blob = blob # Use last attempt even if over limit + if master_process: + log0(f"WARNING: artifact still over limit after all fallbacks") + if master_process: + with open(_ptz_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + log0(f"Serialized model quant+{_COMPRESSOR}-{chosen_level}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open(_ptz_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + legal_ttt = bool(int(os.environ.get("LEGAL_TTT", "0"))) + if args.ttt_enabled and not legal_ttt: + # --- 
Invalid two-pass TTT (adapt then eval separately) --- + if distributed: + dist.barrier() + for block in eval_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + log0(f"ttt:start score-first optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"chunk_tokens={args.ttt_chunk_tokens}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, + rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + if legal_ttt and args.ttt_enabled: + # Legal single-pass TTT: score → train interleaved per chunk + log0(f"legal_ttt:start stride={args.eval_stride} " + f"optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + sw_val_loss, sw_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + log_fn=log0, + ) + else: + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + + +==================================================================================================== +Running Python 3.12.7 | packaged by 
conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] +Running PyTorch 2.10.0+cu128 +Tue Mar 31 17:19:38 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.172.08 Driver Version: 570.172.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA A100-PCIE-40GB On | 00000000:65:00.0 Off | 0 | +| N/A 43C P0 49W / 250W | 423MiB / 40960MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 355186 C ...ameter_golf/.venv/bin/python3 414MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:2025 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/7000 val_loss:6.9336 val_bpb:4.1065 train_time:0ms step_avg:0.01ms +step:1/7000 train_loss:7.1343 train_time:1860ms step_avg:1860.48ms +step:2/7000 train_loss:9.1272 train_time:3803ms step_avg:1901.28ms +step:3/7000 train_loss:8.0549 train_time:5744ms step_avg:1914.78ms +step:4/7000 train_loss:7.4630 train_time:7688ms step_avg:1921.99ms +step:5/7000 train_loss:7.2148 train_time:9633ms step_avg:1926.66ms +step:6/7000 train_loss:7.1310 train_time:11577ms step_avg:1929.52ms +step:7/7000 train_loss:6.9741 train_time:13525ms step_avg:1932.07ms +step:8/7000 train_loss:6.7660 train_time:15472ms step_avg:1933.95ms +step:9/7000 train_loss:6.4398 train_time:17421ms step_avg:1935.71ms +step:10/7000 train_loss:6.1453 train_time:19372ms step_avg:1937.16ms +step:200/7000 train_loss:2.7566 train_time:392618ms step_avg:1963.09ms +step:400/7000 train_loss:2.5625 train_time:784888ms step_avg:1962.22ms +step:500/7000 val_loss:2.3400 val_bpb:1.3859 train_time:980822ms step_avg:1961.64ms +step:600/7000 train_loss:2.3837 
train_time:1176807ms step_avg:1961.35ms +step:800/7000 train_loss:2.3724 train_time:1568577ms step_avg:1960.72ms +step:1000/7000 train_loss:2.3572 train_time:1960683ms step_avg:1960.68ms +step:1000/7000 val_loss:2.2068 val_bpb:1.3070 train_time:1960686ms step_avg:1960.69ms +step:1200/7000 train_loss:2.3121 train_time:2352803ms step_avg:1960.67ms +step:1400/7000 train_loss:2.3508 train_time:2745049ms step_avg:1960.75ms +step:1500/7000 val_loss:2.1638 val_bpb:1.2815 train_time:2941195ms step_avg:1960.80ms +step:1600/7000 train_loss:2.1936 train_time:3137383ms step_avg:1960.86ms +step:1800/7000 train_loss:2.2421 train_time:3530037ms step_avg:1961.13ms +step:2000/7000 train_loss:2.1327 train_time:3922761ms step_avg:1961.38ms +step:2000/7000 val_loss:2.1089 val_bpb:1.2490 train_time:3922764ms step_avg:1961.38ms +step:2200/7000 train_loss:2.2111 train_time:4315287ms step_avg:1961.49ms +step:2400/7000 train_loss:2.1628 train_time:4708019ms step_avg:1961.67ms +step:2500/7000 val_loss:2.0850 val_bpb:1.2349 train_time:4904412ms step_avg:1961.76ms +step:2600/7000 train_loss:2.1636 train_time:5100872ms step_avg:1961.87ms +step:2800/7000 train_loss:2.2077 train_time:5493564ms step_avg:1961.99ms +step:3000/7000 train_loss:2.1581 train_time:5886202ms step_avg:1962.07ms +step:3000/7000 val_loss:2.0716 val_bpb:1.2269 train_time:5886206ms step_avg:1962.07ms +step:3200/7000 train_loss:2.1613 train_time:6278939ms step_avg:1962.17ms +step:3400/7000 train_loss:2.1352 train_time:6671814ms step_avg:1962.30ms +step:3500/7000 val_loss:2.0634 val_bpb:1.2221 train_time:6868256ms step_avg:1962.36ms +step:3600/7000 train_loss:2.1780 train_time:7064671ms step_avg:1962.41ms +step:3800/7000 train_loss:2.1471 train_time:7457455ms step_avg:1962.49ms +step:4000/7000 train_loss:2.2282 train_time:7850357ms step_avg:1962.59ms +step:4000/7000 val_loss:2.0582 val_bpb:1.2190 train_time:7850360ms step_avg:1962.59ms +step:4200/7000 train_loss:2.1200 train_time:8243179ms step_avg:1962.66ms +step:4400/7000 train_loss:2.1026 train_time:8635741ms step_avg:1962.67ms +step:4500/7000 val_loss:2.0417 val_bpb:1.2092 train_time:8832080ms step_avg:1962.68ms +step:4600/7000 train_loss:2.0850 train_time:9028513ms step_avg:1962.72ms +step:4800/7000 train_loss:2.2164 train_time:9421306ms step_avg:1962.77ms +step:5000/7000 train_loss:2.1292 train_time:9814044ms step_avg:1962.81ms +step:5000/7000 val_loss:2.0190 val_bpb:1.1958 train_time:9814048ms step_avg:1962.81ms +step:5200/7000 train_loss:2.0990 train_time:10206705ms step_avg:1962.83ms +step:5400/7000 train_loss:2.0739 train_time:10599477ms step_avg:1962.87ms +step:5500/7000 val_loss:1.9965 val_bpb:1.1825 train_time:10795770ms step_avg:1962.87ms +step:5600/7000 train_loss:2.0598 train_time:10992139ms step_avg:1962.88ms +step:5800/7000 train_loss:2.0450 train_time:11384873ms step_avg:1962.91ms +step:6000/7000 train_loss:2.0260 train_time:11777706ms step_avg:1962.95ms +step:6000/7000 val_loss:1.9731 val_bpb:1.1686 train_time:11777710ms step_avg:1962.95ms +step:6200/7000 train_loss:2.1195 train_time:12170526ms step_avg:1962.99ms +step:6400/7000 train_loss:2.0991 train_time:12563320ms step_avg:1963.02ms +step:6500/7000 val_loss:1.9430 val_bpb:1.1508 train_time:12759725ms step_avg:1963.03ms +step:6600/7000 train_loss:1.9991 train_time:12956061ms step_avg:1963.04ms +step:6800/7000 train_loss:2.0804 train_time:13348702ms step_avg:1963.04ms +step:7000/7000 train_loss:1.9237 train_time:13741101ms step_avg:1963.01ms +step:7000/7000 val_loss:1.9180 val_bpb:1.1359 train_time:13741104ms step_avg:1963.01ms 
+peak memory allocated: 25832 MiB reserved: 26006 MiB +ema:applying EMA weights +save_paths: pt=final_model_pr940_nflow_s2025_55398557.pt ptz=final_model_pr940_nflow_s2025_55398557.int6.ptz +Serialized model: 107294152 bytes +Code size: 115032 bytes +quant_try int6 zstd-16: 16235273 bytes (limit 15884968) +quant_try int6 zstd-1: 16306740 bytes (limit 15884968) +quant_try int6 zstd-17: 16260153 bytes (limit 15884968) +quant_try int6 zstd-2: 16312469 bytes (limit 15884968) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15932289 bytes (limit 15884968) +quant_try int5[1L] zstd-1: 16009919 bytes (limit 15884968) +quant_try int5[1L] zstd-17: 15954917 bytes (limit 15884968) +quant_try int5[1L] zstd-2: 16059927 bytes (limit 15884968) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15630918 bytes (limit 15884968) +Serialized model quant+zstd-16: 15630918 bytes +Total submission size: 15745950 bytes +final_int6_roundtrip val_loss:1.9323 val_bpb:1.1444 eval_time:67830ms +final_int6_roundtrip_exact val_loss:1.93234887 val_bpb:1.14444585 +final_int6_sliding_window val_loss:1.8924 val_bpb:1.1208 stride:64 eval_time:1699977ms +final_int6_sliding_window_exact val_loss:1.89237638 val_bpb:1.12077485 diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_eval_nflow7k_legal_ttt.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_eval_nflow7k_legal_ttt.sh new file mode 100644 index 0000000000..523a17bc81 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_eval_nflow7k_legal_ttt.sh @@ -0,0 +1,114 @@ +#!/bin/bash +############################################################################# +# Eval: NativeFlowMatcher 7k checkpoint with LEGAL single-pass TTT +# Model: final_model_pr940_nflow_55342820.pt (27,530,952 params) +# Baseline (no TTT): sliding BPB = 1.12312 +# Expected runtime: ~2h (legal TTT interleaved with sliding window) +############################################################################# +#SBATCH --job-name=eval_nflow7k_lttt +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=03:00:00 +#SBATCH --account=medcam +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/eval_nflow7k_legal_ttt_%j.out +#SBATCH --error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/eval_nflow7k_legal_ttt_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf + +set -euo pipefail + +# ── Environment ────────────────────────────────────────────────────────── +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate +export CUDA_VISIBLE_DEVICES=0 + +mkdir -p /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs + +# ── Run identifiers ───────────────────────────────────────────────────── +export RUN_ID="eval_nflow7k_legal_ttt_${SLURM_JOB_ID}" +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/eval_nflow7k_legal_ttt_${SLURM_JOB_ID}.txt" + +# ── Eval-only: load checkpoint ────────────────────────────────────────── +export EVAL_ONLY="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_55342820/models/final_model_pr940_nflow_55342820.pt" + +# ── Data paths ─────────────────────────────────────────────────────────── +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export 
TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + +# ── Architecture (must match slurm_pr940_nflow_7k.sh exactly) ─────────── +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# ── NativeFlowMatcher (must match nflow training config) ──────────────── +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# ── Disable other modules ─────────────────────────────────────────────── +export FLOW_ENABLED=0 +export E2E_TTT_ENABLED=0 +export EMA_ENABLED=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 +export CANON_LAST_N=0 + +# ── Legal TTT config ──────────────────────────────────────────────────── +export TTT_ENABLED=1 +export LEGAL_TTT=1 +export TTT_OPTIMIZER=sgd +export TTT_LR=0.002 +export TTT_EPOCHS=10 +export TTT_FREEZE_BLOCKS=2 +export TTT_BATCH_SEQS=32 +export TTT_CHUNK_TOKENS=32768 +export TTT_GRAD_CLIP=1.0 +export TTT_MOMENTUM=0.9 + +# ── Eval config ────────────────────────────────────────────────────────── +export EVAL_STRIDE=64 +export SEED=42 + +# ── Training params (unused but required by argparse) ──────────────────── +export ITERATIONS=7000 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 + +# ── Optimizer params (unused but required) ─────────────────────────────── +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 + +# ── Run ────────────────────────────────────────────────────────────────── +echo "=== Eval NativeFlow 7k with Legal TTT ===" +echo "Checkpoint: ${EVAL_ONLY}" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Start: $(date)" + +torchrun --standalone --nproc_per_node=1 train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "End: $(date)" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_eval_nflow7k_nottt.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_eval_nflow7k_nottt.sh new file mode 100644 index 0000000000..e415c393ad --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_eval_nflow7k_nottt.sh @@ -0,0 +1,97 @@ +#!/bin/bash +############################################################################# +# Eval: NativeFlowMatcher 7k checkpoint – NO TTT (sliding window only) +# Model: final_model_pr940_nflow_55342820.pt (27,530,952 params) +# Purpose: Clean sliding-window-only eval for comparison with legal TTT +# Note: Training-time eval already showed 1.12312 but this run ensures +# consistent eval conditions with the legal TTT counterpart. 
+# Expected runtime: ~30min (sliding window only) +############################################################################# +#SBATCH --job-name=eval_nflow7k_nottt +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=01:00:00 +#SBATCH --account=medcam +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/eval_nflow7k_nottt_%j.out +#SBATCH --error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/eval_nflow7k_nottt_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf + +set -euo pipefail + +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate +export CUDA_VISIBLE_DEVICES=0 + +mkdir -p /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs + +export RUN_ID="eval_nflow7k_nottt_${SLURM_JOB_ID}" +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/eval_nflow7k_nottt_${SLURM_JOB_ID}.txt" + +export EVAL_ONLY="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/runs/nflow_55342820/models/final_model_pr940_nflow_55342820.pt" + +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + +# ── Architecture (match nflow training) ────────────────────────────────── +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# NativeFlowMatcher ON (must match checkpoint) +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# All other modules OFF +export FLOW_ENABLED=0 +export E2E_TTT_ENABLED=0 +export EMA_ENABLED=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 +export CANON_LAST_N=0 +export TTT_ENABLED=0 + +export EVAL_STRIDE=64 +export SEED=42 + +export ITERATIONS=7000 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 + +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 + +echo "=== Eval NativeFlow 7k – No TTT ===" +echo "Checkpoint: ${EVAL_ONLY}" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Start: $(date)" + +torchrun --standalone --nproc_per_node=1 train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "End: $(date)" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_pr940_nflow_7k.sh b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_pr940_nflow_7k.sh new file mode 100755 index 0000000000..e454908f4e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/supplementary/slurm_pr940_nflow_7k.sh @@ -0,0 +1,102 @@ +#!/bin/bash +#SBATCH --job-name=pr940_nflow +#SBATCH --partition=gpu +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=14:00:00 +#SBATCH --nice=0 +#SBATCH --output=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_%j.out +#SBATCH 
--error=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_%j.err +#SBATCH --chdir=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf +#SBATCH --account=medcam + +echo "=== Job Info ===" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "GPUs: $CUDA_VISIBLE_DEVICES" +echo "Start: $(date)" +echo "================" + +# Activate environment +source /hpfs/scratch/gpfs/mcclec07/code/parameter_golf/.venv/bin/activate + +# Force single GPU +export CUDA_VISIBLE_DEVICES=0 + +# --- Run & Data --- +export RUN_ID="pr940_nflow_${SLURM_JOB_ID}" +export SEED=42 +export DATA_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024" +export TOKENIZER_PATH="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model" +export VOCAB_SIZE=1024 + +# --- Training Schedule --- +export MAX_WALLCLOCK_SECONDS=0 +export ITERATIONS=7000 +export VAL_LOSS_EVERY=500 +export WARMDOWN_ITERS=2800 +export WARMUP_STEPS=20 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 + +# --- Architecture --- +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3 +export TIE_EMBEDDINGS=1 +export BIGRAM_VOCAB_SIZE=4096 +export BIGRAM_DIM=128 +export ROPE_DIMS=16 +export ROPE_BASE=10000 +export LOGIT_SOFTCAP=30.0 +export LN_SCALE=1 +export XSA_LAST_N=11 +export VALUE_RESIDUAL=1 +export GATED_ATTENTION=1 +export QK_GAIN_INIT=1.5 +export LEAKY_RELU=1 +export LEAKY_SLOPE=0.5 + +# --- Optimizer --- +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export GRAD_CLIP_NORM=0.3 +export EVAL_STRIDE=64 + +# --- EMA --- +export EMA_ENABLED=1 +export EMA_DECAY=0.997 + +# --- Disabled features --- +export TTT_ENABLED=0 +export CANON_LAST_N=0 +export SWA_ENABLED=0 +export QAT_ENABLED=0 + +# --- FlowRefiner (disabled) --- +export FLOW_ENABLED=0 + +# --- NativeFlowMatcher --- +export NATIVE_FLOW_ENABLED=1 +export NATIVE_FLOW_HIDDEN_DIM=256 +export NATIVE_FLOW_INIT_SCALE=0.01 +export NATIVE_FLOW_LOSS_WEIGHT=0.1 + +# --- Log file --- +LOGFILE="/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/logs/pr940_nflow_${SLURM_JOB_ID}.txt" + +echo "Running training with RUN_ID=$RUN_ID" +echo "Log: $LOGFILE" + +torchrun --standalone --nproc_per_node=1 \ + train_gpt_pr940.py 2>&1 | tee "$LOGFILE" + +echo "=== Done: $(date) ===" diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/train.log b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/train.log new file mode 100644 index 0000000000..9dfe851f26 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/train.log @@ -0,0 +1,2479 @@ +logs/pr940_nflow_55342820.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 
max_wallclock_seconds:0.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/7000 val_loss:6.9319 val_bpb:4.1055 train_time:0ms step_avg:0.01ms +step:1/7000 train_loss:7.1329 train_time:1910ms step_avg:1910.05ms +step:2/7000 train_loss:9.1205 train_time:3872ms step_avg:1935.75ms +step:3/7000 train_loss:8.0686 train_time:5836ms step_avg:1945.21ms +step:4/7000 train_loss:7.3950 train_time:7800ms step_avg:1949.95ms +step:5/7000 train_loss:7.1158 train_time:9768ms step_avg:1953.57ms +step:6/7000 train_loss:7.1063 train_time:11738ms step_avg:1956.33ms +step:7/7000 train_loss:6.9423 train_time:13711ms step_avg:1958.76ms +step:8/7000 train_loss:6.7316 train_time:15684ms step_avg:1960.48ms +step:9/7000 train_loss:6.4037 train_time:17659ms step_avg:1962.13ms +step:10/7000 train_loss:6.1082 train_time:19636ms step_avg:1963.62ms +step:200/7000 train_loss:2.7242 train_time:396315ms step_avg:1981.58ms +step:400/7000 train_loss:2.5459 train_time:791709ms step_avg:1979.27ms +step:500/7000 val_loss:2.3323 val_bpb:1.3813 train_time:989535ms step_avg:1979.07ms +step:600/7000 train_loss:2.3883 train_time:1187390ms step_avg:1978.98ms +step:800/7000 train_loss:2.3764 train_time:1583181ms step_avg:1978.98ms +step:1000/7000 train_loss:2.3550 train_time:1979099ms step_avg:1979.10ms +step:1000/7000 val_loss:2.2047 val_bpb:1.3058 train_time:1979102ms step_avg:1979.10ms +step:1200/7000 train_loss:2.3096 train_time:2375020ms step_avg:1979.18ms +step:1400/7000 train_loss:2.3519 train_time:2771203ms step_avg:1979.43ms +step:1500/7000 val_loss:2.1639 val_bpb:1.2816 train_time:2969412ms step_avg:1979.61ms +step:1600/7000 train_loss:2.1972 train_time:3167603ms step_avg:1979.75ms +step:1800/7000 train_loss:2.2434 train_time:3564368ms step_avg:1980.20ms +step:2000/7000 train_loss:2.1350 train_time:3961132ms step_avg:1980.57ms +step:2000/7000 val_loss:2.1104 val_bpb:1.2499 train_time:3961136ms step_avg:1980.57ms +step:2200/7000 train_loss:2.2132 train_time:4357740ms step_avg:1980.79ms +step:2400/7000 train_loss:2.1653 train_time:4754499ms step_avg:1981.04ms +step:2500/7000 val_loss:2.0861 val_bpb:1.2355 train_time:4952992ms step_avg:1981.20ms +step:2600/7000 train_loss:2.1657 train_time:5151477ms step_avg:1981.34ms +step:2800/7000 train_loss:2.2096 train_time:5548340ms step_avg:1981.55ms +step:3000/7000 train_loss:2.1612 train_time:5945254ms step_avg:1981.75ms +step:3000/7000 val_loss:2.0739 val_bpb:1.2283 train_time:5945269ms step_avg:1981.76ms +step:3200/7000 train_loss:2.1662 train_time:6342108ms step_avg:1981.91ms +step:3400/7000 train_loss:2.1383 train_time:6738727ms step_avg:1981.98ms +step:3500/7000 val_loss:2.0657 val_bpb:1.2234 train_time:6937035ms step_avg:1982.01ms +step:3600/7000 train_loss:2.1801 train_time:7135459ms step_avg:1982.07ms +step:3800/7000 train_loss:2.1531 train_time:7532083ms step_avg:1982.13ms +step:4000/7000 train_loss:2.2310 train_time:7928770ms step_avg:1982.19ms +step:4000/7000 val_loss:2.0597 val_bpb:1.2199 train_time:7928773ms step_avg:1982.19ms +step:4200/7000 train_loss:2.1213 train_time:8325600ms step_avg:1982.29ms +step:4400/7000 train_loss:2.1049 train_time:8722496ms step_avg:1982.39ms +step:4500/7000 val_loss:2.0440 val_bpb:1.2106 train_time:8920832ms 
step_avg:1982.41ms +step:4600/7000 train_loss:2.0877 train_time:9119160ms step_avg:1982.43ms +step:4800/7000 train_loss:2.2189 train_time:9515876ms step_avg:1982.47ms +step:5000/7000 train_loss:2.1324 train_time:9912614ms step_avg:1982.52ms +step:5000/7000 val_loss:2.0220 val_bpb:1.1975 train_time:9912617ms step_avg:1982.52ms +step:5200/7000 train_loss:2.1037 train_time:10309364ms step_avg:1982.57ms +step:5400/7000 train_loss:2.0787 train_time:10705965ms step_avg:1982.59ms +step:5500/7000 val_loss:1.9997 val_bpb:1.1843 train_time:10904304ms step_avg:1982.60ms +step:5600/7000 train_loss:2.0623 train_time:11102510ms step_avg:1982.59ms +step:5800/7000 train_loss:2.0481 train_time:11499073ms step_avg:1982.60ms +step:6000/7000 train_loss:2.0294 train_time:11895613ms step_avg:1982.60ms +step:6000/7000 val_loss:1.9767 val_bpb:1.1707 train_time:11895616ms step_avg:1982.60ms +step:6200/7000 train_loss:2.1229 train_time:12292332ms step_avg:1982.63ms +step:6400/7000 train_loss:2.1040 train_time:12689010ms step_avg:1982.66ms +step:6500/7000 val_loss:1.9464 val_bpb:1.1527 train_time:12887389ms step_avg:1982.68ms +step:6600/7000 train_loss:2.0020 train_time:13085769ms step_avg:1982.69ms +step:6800/7000 train_loss:2.0838 train_time:13482720ms step_avg:1982.75ms +step:7000/7000 train_loss:1.9271 train_time:13879411ms step_avg:1982.77ms +step:7000/7000 val_loss:1.9215 val_bpb:1.1380 train_time:13879414ms step_avg:1982.77ms +peak memory allocated: 25832 MiB reserved: 26006 MiB +ema:applying EMA weights +save_paths: pt=final_model_pr940_nflow_55342820.pt ptz=final_model_pr940_nflow_55342820.int6.ptz +Serialized model: 107292674 bytes +Code size: 104738 bytes +quant_try int6 zstd-16: 16242245 bytes (limit 15895262) +quant_try int6 zstd-1: 16298729 bytes (limit 15895262) +quant_try int6 zstd-17: 16250339 bytes (limit 15895262) +quant_try int6 zstd-2: 16301974 bytes (limit 15895262) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 15927696 bytes (limit 15895262) +quant_try int5[1L] zstd-1: 16001920 bytes (limit 15895262) +quant_try int5[1L] zstd-17: 16297960 bytes (limit 15895262) +quant_try int5[1L] zstd-2: 16053283 bytes (limit 15895262) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15630744 bytes (limit 15895262) +Serialized model quant+zstd-16: 15630744 bytes +Total submission size: 15735482 bytes +final_int6_roundtrip val_loss:1.9363 val_bpb:1.1468 eval_time:67775ms +final_int6_roundtrip_exact val_loss:1.93630746 val_bpb:1.14679034 +final_int6_sliding_window val_loss:1.8963 val_bpb:1.1231 stride:64 eval_time:1713117ms +final_int6_sliding_window_exact val_loss:1.89632895 val_bpb:1.12311579 +-Linear refiner (Sun et al., 2024) + e2e_ttt_enabled = bool(int(os.environ.get("E2E_TTT_ENABLED", "0"))) + e2e_ttt_num_heads = int(os.environ.get("E2E_TTT_NUM_HEADS", "8")) + e2e_ttt_mini_batch = int(os.environ.get("E2E_TTT_MINI_BATCH", "16")) + e2e_ttt_base_lr = float(os.environ.get("E2E_TTT_BASE_LR", "1.0")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def 
__init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
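+# Worked example (numbers taken from the seed=42 roundtrip eval logged above, shown
+# here for orientation only): val_loss 1.93630746 gives
+#   bits_per_token = val_loss / ln(2) ~= 2.7935
+# and the reported val_bpb of 1.14679034 then implies tokens_per_byte ~= 0.4105
+# (about 2.44 bytes per token) on this validation set, since
+#   val_bpb = (val_loss / ln 2) * (total_tokens / total_bytes)
+# which is exactly how eval_val() below computes it.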
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + 
token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,vr_lambda,attn_gate,canon_a,canon_c,delta_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, qmax: int = 127) -> tuple[Tensor, Tensor]: + """Quantize to [-qmax, qmax] range. 
Default int8 (qmax=127), int6 (qmax=31), int5 (qmax=15).""" + t32 = t.float() + qmin = -qmax + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
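+        # Illustrative note (follows from the sizes above, not an extra rule): for the
+        # 512-dim model this keeps per-block attn_scale / mlp_scale vectors (512 elements)
+        # and the attn_gate Linear (8x512 = 4,096 weights) out of the int path entirely,
+        # while the 1536x512 MLP matrices (786,432 elements) fall through to per-row
+        # quantization below; the cut-off is INT8_KEEP_FLOAT_MAX_NUMEL = 65,536.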
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 for MLP layers 3-7 to save artifact space + int6_mlp_layers = os.environ.get("INT6_MLP_LAYERS", "") + qmax = 127 # default int8 + if int6_mlp_layers: + for li in int6_mlp_layers.split(","): + if li.strip() and f"blocks.{li.strip()}.mlp" in name and t.ndim == 2: + qmax = 31 # int6 + break + q, s = quantize_float_tensor(t, qmax=qmax) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
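+    # Worked example with the config used in these runs (world_size=1, grad_accum_steps=8,
+    # TRAIN_BATCH_TOKENS=786432, TRAIN_SEQ_LEN=2048): each next_batch() call consumes
+    # 786432 / (1 * 8) + 1 = 98,305 tokens and reshapes them into x, y of shape (48, 2048).
+    # Numbers are illustrative only; they follow directly from next_batch() below.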
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round: bool = False + _soft_round_alpha: float = 1.0 + _quant_percentile: float = float(os.environ.get("QUANT_PERCENTILE", "1.0")) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + w32 = self.weight.float() + pct = CastedLinear._quant_percentile + row_max = (torch.quantile(w32.abs(), pct, dim=1) if pct < 1.0 + else w32.abs().amax(dim=1)).detach() + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + r = w32 / scale[:, None] + if CastedLinear._soft_round: + alpha = CastedLinear._soft_round_alpha + r_frac = r - r.detach().floor() - 0.5 + norm = torch.tanh(torch.tensor(alpha * 0.5, device=r.device, dtype=r.dtype)) + r_soft = r.detach().floor() + 0.5 + torch.tanh(alpha * r_frac) / (2.0 * norm) + w_q = (torch.clamp(r_soft, -32, 31) * scale[:, None]).to(x.dtype) + w = w_q # soft-round is differentiable, no STE needed + else: + with torch.no_grad(): + w_q = (torch.clamp(torch.round(r), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. 
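+    # Sketch of the scaling applied in forward(): for seq_len > train_seq_len,
+    #   new_base = base * (seq_len / train_seq_len) ** (rd / (rd - 2))
+    # e.g. with rope_dims=16, base=10000 and train_seq_len=1024, a 2048-token training
+    # sequence gives new_base ~= 10000 * 2 ** (16/14) ~= 22,000 (illustrative arithmetic).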
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + value_residual: bool = False, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = 
v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + fa_dtype = torch.bfloat16 + if _USE_FA3: + y = flash_attn_3_func(q.to(fa_dtype), k.to(fa_dtype), v.to(fa_dtype), causal=True) + else: + # SDPA fallback: (B, T, H, D) -> (B, H, T, D), expand KV for GQA + q_t = q.to(fa_dtype).transpose(1, 2) + k_t = k.to(fa_dtype).transpose(1, 2) + v_t = v.to(fa_dtype).transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(rep, dim=1) + v_t = v_t.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True) + y = y.transpose(1, 2) # (B, H, T, D) -> (B, T, H, D) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)) # (B, T, num_heads) + y = y * gate.unsqueeze(-1) # (B, T, H, 1) broadcast to (B, T, H, D) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.use_leaky = bool(int(os.environ.get("LEAKY_RELU", "1"))) + self.leaky_slope = 
float(os.environ.get("LEAKY_SLOPE", "0.5")) + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), self.leaky_slope) if self.use_leaky else torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class CanonAC(nn.Module): + """Canon Autoregressive Convolution with DeltaGate. Manual shift+mul (no Conv1d).""" + def __init__(self, dim: int, kernel: int = 4, delta_gate_init: float = -4.0): + super().__init__() + self.kernel = kernel + self.weight = nn.Parameter(torch.zeros(kernel, dim)) + self.delta_gate_logit = nn.Parameter(torch.tensor(delta_gate_init)) + + def forward(self, x: Tensor) -> Tensor: + B, T, D = x.shape + K = self.kernel + w = self.weight.to(x.dtype) + x_pad = F.pad(x, (0, 0, K - 1, 0)) + y = w[0] * x_pad[:, K - 1:, :] + for k in range(1, K): + y = y + w[k] * x_pad[:, K - 1 - k : T + K - 1 - k, :] + gate = torch.sigmoid(self.delta_gate_logit.to(x.dtype)) + return x + gate * y + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_kernel: int = 0, + canon_delta_gate_init: float = -4.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, value_residual=value_residual, + gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.canon_a = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + self.canon_c = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_in = self.attn_norm(x) * s + if self.canon_a is not None: + attn_in = self.canon_a(attn_in) + attn_out, raw_v = self.attn(attn_in, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s + if self.canon_c is not None: + mlp_in = self.canon_c(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + return x, raw_v + + +class FlowRefiner(nn.Module): + """1-step flow matching refiner in low-dim latent space. + + Projects hidden states to a low-dim latent, applies a learned velocity + field (1-step Euler), projects back. Output is additive adjustment to + hidden states. Initialized near-zero so the refiner starts as identity. 
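+
+    Schematically (matching forward() below): with z = down_proj(x),
+        output = sigmoid(gate) * up_proj(z + velocity_net(z))
+    and the caller adds this output to the hidden states.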
+ """ + def __init__(self, model_dim: int, latent_dim: int = 64, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.down_proj = nn.Linear(model_dim, latent_dim, bias=False) + self.velocity_net = nn.Sequential( + nn.Linear(latent_dim, hidden_dim, bias=True), + nn.GELU(), + nn.Linear(hidden_dim, latent_dim, bias=True), + ) + self.up_proj = nn.Linear(latent_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.velocity_net[2]._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.down_proj.weight, std=self._init_scale) + nn.init.normal_(self.velocity_net[0].weight, std=self._init_scale) + nn.init.zeros_(self.velocity_net[0].bias) + nn.init.zeros_(self.velocity_net[2].weight) + nn.init.zeros_(self.velocity_net[2].bias) + nn.init.normal_(self.up_proj.weight, std=self._init_scale) + + def forward(self, x: Tensor) -> Tensor: + z = self.down_proj(x) + v = self.velocity_net(z) + z_refined = z + v + delta = self.up_proj(z_refined) + return torch.sigmoid(self.gate) * delta + + +class NativeFlowMatcher(nn.Module): + """Conditional Flow Matching refiner for hidden states. + + During training: computes auxiliary CFM loss on interpolated states. + During inference: applies a single Euler step correction at t=1. + """ + def __init__(self, model_dim: int, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.model_dim = model_dim + # Sinusoidal time embedding → projected to hidden_dim + self.time_proj = nn.Sequential( + nn.Linear(model_dim, hidden_dim, bias=True), + nn.GELU(), + ) + # Velocity network: x → hidden → x (with time conditioning via addition) + self.v_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.v_act = nn.GELU() + self.v_out = nn.Linear(hidden_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.v_out._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.v_in.weight, std=self._init_scale) + nn.init.zeros_(self.v_in.bias) + nn.init.normal_(self.time_proj[0].weight, std=self._init_scale) + nn.init.zeros_(self.time_proj[0].bias) + nn.init.zeros_(self.v_out.weight) + + @staticmethod + def _sinusoidal_time_emb(t: Tensor, dim: int) -> Tensor: + """Sinusoidal positional embedding for scalar time t. t: (...,) -> (..., dim).""" + half_dim = dim // 2 + emb = math.log(10000.0) / max(half_dim - 1, 1) + emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb) + emb = t.unsqueeze(-1) * emb # (..., half_dim) + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) # (..., dim) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + def _velocity(self, x: Tensor, t_emb: Tensor) -> Tensor: + """Compute velocity v(x, t). x: (..., model_dim), t_emb: (..., hidden_dim).""" + h = self.v_in(x) # (..., hidden_dim) + h = h + t_emb # time conditioning via addition + h = self.v_act(h) # GELU + return self.v_out(h) # (..., model_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (correction, cfm_loss). + + During training: cfm_loss is the flow matching objective (> 0). + During eval: cfm_loss is zero. + correction: gated velocity at t=1, to be added to x. 
+ """ + B_seq = x.shape[:-1] # works for any leading dimensions + D = self.model_dim + + # Inference path: velocity at t=1 (clean input, no noise) + t_ones = torch.ones(*B_seq, device=x.device, dtype=x.dtype) + t_emb_ones = self.time_proj(self._sinusoidal_time_emb(t_ones, D)) + v_at_1 = self._velocity(x, t_emb_ones) + correction = torch.sigmoid(self.gate) * v_at_1 + + # Training path: auxiliary CFM loss + if self.training: + # Sample random t ~ U(0,1) per position + t = torch.rand(*B_seq, device=x.device, dtype=x.dtype) + z = torch.randn_like(x) # noise + # OT interpolant: x_t = (1-t)*z + t*x + t_expanded = t.unsqueeze(-1) # (..., 1) + x_t = (1.0 - t_expanded) * z + t_expanded * x.detach() + # Target velocity = x - z (OT conditional velocity field) + target_v = x.detach() - z + # Predict velocity at interpolated point + t_emb = self.time_proj(self._sinusoidal_time_emb(t, D)) + pred_v = self._velocity(x_t, t_emb) + cfm_loss = F.mse_loss(pred_v, target_v) + else: + cfm_loss = x.new_zeros(()) + + return correction, cfm_loss + + +class TTTLinearRefiner(nn.Module): + """E2E TTT-Linear refiner (Sun et al., 2024, arXiv:2407.04620). + + Applied to final hidden states [B, L, D] before lm_head. The hidden + state is a per-head linear model W updated via mini-batch gradient + descent on a learned self-supervised reconstruction task. Uses the + dual form for GPU efficiency. Trained end-to-end: the outer loop + optimizes theta_K, theta_V, theta_Q projections and the inner-loop + initialization W_0, making the model learn WHAT to compress. + """ + + def __init__(self, model_dim: int, num_heads: int = 8, + mini_batch_size: int = 16, ttt_base_lr: float = 1.0): + super().__init__() + self.model_dim = model_dim + self.num_heads = num_heads + self.head_dim = model_dim // num_heads + self.mini_batch_size = mini_batch_size + self.ttt_base_lr = ttt_base_lr + self.eta = ttt_base_lr / self.head_dim + + # Learnable projections (outer-loop) + self.theta_K = nn.Linear(model_dim, model_dim, bias=False) + self.theta_V = nn.Linear(model_dim, model_dim, bias=False) + self.theta_Q = nn.Linear(model_dim, model_dim, bias=False) + + # Inner-loop model: W_0, b_0 + self.W1 = nn.Parameter(torch.normal(0, 0.02, + size=(num_heads, self.head_dim, self.head_dim))) + self.b1 = nn.Parameter(torch.zeros(num_heads, 1, self.head_dim)) + + # Inner-loop LayerNorm (outer-loop trainable) + self.ttt_ln_w = nn.Parameter(torch.ones(num_heads, self.head_dim)) + self.ttt_ln_b = nn.Parameter(torch.zeros(num_heads, self.head_dim)) + + # Output projection + gate (starts near zero for residual safety) + self.o_proj = nn.Linear(model_dim, model_dim, bias=False) + self.post_norm = nn.LayerNorm(model_dim, eps=1e-6) + self.gate = nn.Parameter(torch.tensor(-5.0)) + nn.init.zeros_(self.o_proj.weight) + self.o_proj._zero_init = True + + @staticmethod + def _ln_fwd(x: Tensor, gamma: Tensor, beta: Tensor, eps: float = 1e-6) -> Tensor: + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x_hat = (x - mu) / torch.sqrt(var + eps) + return gamma * x_hat + beta + + @staticmethod + def _ln_fused_l2_bwd(x: Tensor, target: Tensor, gamma: Tensor, beta: Tensor, + eps: float = 1e-6) -> Tensor: + D = x.shape[-1] + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + std = torch.sqrt(var + eps) + x_hat = (x - mu) / std + y = gamma * x_hat + beta + grad_output = y - target + grad_x_hat = grad_output * gamma + z = (1.0 / D) * ( + D * grad_x_hat + - grad_x_hat.sum(dim=-1, keepdim=True) + - x_hat * 
(grad_x_hat * x_hat).sum(dim=-1, keepdim=True) + ) / std + return z + + def forward(self, x: Tensor) -> Tensor: + """x: [B, L, D]. Returns [B, L, D] additive refinement (gated).""" + B, L, D = x.shape + NH, HD, b = self.num_heads, self.head_dim, self.mini_batch_size + eta = self.eta + + num_mb = L // b + usable_L = num_mb * b + x_proc = x[:, :usable_L] if usable_L < L else x + + # Project: [B, L, D] -> [B, NH, L, HD] + XK = self.theta_K(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XV = self.theta_V(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XQ = self.theta_Q(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + + # Reshape into mini-batches: [B, NH, num_mb, b, HD] + XK = XK.reshape(B, NH, num_mb, b, HD) + XV = XV.reshape(B, NH, num_mb, b, HD) + XQ = XQ.reshape(B, NH, num_mb, b, HD) + + ln_w = self.ttt_ln_w.reshape(1, NH, 1, HD) + ln_b = self.ttt_ln_b.reshape(1, NH, 1, HD) + + # Causal mask for dual form + causal = torch.tril(torch.ones(b, b, device=x.device, dtype=x.dtype)) + + # Initialize inner-loop model + W = self.W1.unsqueeze(0).expand(B, -1, -1, -1).clone() + bias = self.b1.unsqueeze(0).expand(B, -1, -1, -1).clone() + + all_out = [] + for m in range(num_mb): + xk = XK[:, :, m] + xv = XV[:, :, m] + xq = XQ[:, :, m] + + Z1 = xk @ W + bias + target = xv - xk + + grad = self._ln_fused_l2_bwd(Z1, target, ln_w, ln_b) + + Attn = causal.unsqueeze(0).unsqueeze(0) * (xq @ xk.transpose(-2, -1)) + cumgrad = (causal.unsqueeze(0).unsqueeze(0) @ grad) + b_bar = bias - eta * cumgrad + Z_bar = xq @ W - eta * Attn @ grad + b_bar + + W = W - eta * xk.transpose(-2, -1) @ grad + bias = bias - eta * grad.sum(dim=-2, keepdim=True) + + Z_bar = self._ln_fwd(Z_bar, ln_w, ln_b) + out_mb = xq + Z_bar + all_out.append(out_mb) + + z = torch.cat(all_out, dim=2).permute(0, 2, 1, 3).reshape(B, usable_L, D) + z = self.post_norm(z) + z = self.o_proj(z) + result = torch.sigmoid(self.gate) * z + + if usable_L < L: + pad = torch.zeros(B, L - usable_L, D, device=x.device, dtype=x.dtype) + result = torch.cat([result, pad], dim=1) + + return result + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_last_n: int = 0, + canon_kernel: int = 4, + canon_delta_gate_init: float = -4.0, + flow_enabled: bool = False, + flow_latent_dim: int = 64, + flow_hidden_dim: int = 256, + flow_init_scale: float = 0.01, + native_flow_enabled: bool = False, + native_flow_hidden_dim: int = 256, + native_flow_init_scale: float = 0.01, + native_flow_loss_weight: float = 0.1, + e2e_ttt_enabled: bool = False, + e2e_ttt_num_heads: int = 8, + e2e_ttt_mini_batch: int = 16, + e2e_ttt_base_lr: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, 
bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + canon_start = num_layers - canon_last_n if canon_last_n > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + value_residual=value_residual, + gated_attention=gated_attention, + canon_kernel=canon_kernel if i >= canon_start else 0, + canon_delta_gate_init=canon_delta_gate_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.flow_refiner = FlowRefiner(model_dim, flow_latent_dim, flow_hidden_dim, flow_init_scale) if flow_enabled else None + self.native_flow = NativeFlowMatcher(model_dim, native_flow_hidden_dim, native_flow_init_scale) if native_flow_enabled else None + self.native_flow_loss_weight = native_flow_loss_weight + self.ttt_refiner = TTTLinearRefiner(model_dim, e2e_ttt_num_heads, e2e_ttt_mini_batch, e2e_ttt_base_lr) if e2e_ttt_enabled else None + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + nflow_loss = x.new_zeros(()) + if self.native_flow is not None: + correction, nflow_loss = self.native_flow(x) + x = x + correction + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + if self.native_flow is not None and self.training: + main_loss = main_loss + self.native_flow_loss_weight * nflow_loss + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + if self.native_flow is not None: + correction, _ = self.native_flow(x) + x = x + correction + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + 
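+# Protocol summary (descriptive comment only; behaviour is defined by the function below):
+# windows of `eval_seq_len` tokens begin every `stride` tokens. Each window is forward-passed
+# in full, but only its trailing `stride` tokens contribute to the loss and byte totals
+# (the first window contributes all of its tokens), so every scored token sees the maximum
+# available left context. The optional n-gram cache (NGRAM_CACHE, default on only when
+# world_size > 1) is a multi-order backoff over orders NGRAM_MIN_ORDER..NGRAM_ORDER
+# (default 2-7) with hashed count tables and an entropy-adaptive mixing weight; tables are
+# updated only after a segment has been scored (score-first).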
+def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + Optionally uses entropy-gated 5-gram cache (NGRAM_CACHE=1).""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # N-gram eval cache with multi-order backoff + entropy-adaptive alpha (PR #702 inspired) + _ngram_default = "1" if world_size > 1 else "0" + use_ngram = bool(int(os.environ.get("NGRAM_CACHE", _ngram_default))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.40")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_order = int(os.environ.get("NGRAM_ORDER", "7")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = ngram_order - ngram_min_order + 1 + ctx_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + print(f"ngram_cache:enabled orders={ngram_min_order}-{ngram_order} backoff " + f"entropy={ngram_entropy} alpha={ngram_alpha} " + f"ent_base={ngram_ent_base} ent_range={ngram_ent_range} " + f"min_count={ngram_min_count} buckets={ngram_buckets}", flush=True) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + scored_nll = nll[i, s:wlen].to(torch.float64) + + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha: compute from model logits (GPU) + if ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ngram_ent_base + ngram_ent_range / ( + 1.0 + np.exp(-ngram_ent_scale * (seg_ent - ngram_ent_thresh))) + + # Precompute hashes for all orders + order_data = [] # (v_idx, ctx_key, full_key) per order + for oi in range(_n_orders): + ctx_w = ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, fill unmatched with lower orders + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if ngram_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll_np = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + scored_nll = torch.from_numpy(seg_nll_np).to(dtype=torch.float64, device=device) + + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + 
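+# Worked conversion (illustrative helper for documentation; never called by the script):
+# the BPB returned above is the mean NLL in nats divided by ln(2), rescaled by
+# tokens-per-byte, which reduces to total_nats / (ln(2) * total_bytes).
+def _bpb_from_totals(loss_sum_nats: float, token_count: float, byte_count: float) -> float:
+ bits_per_token = (loss_sum_nats / token_count) / math.log(2.0)
+ return bits_per_token * (token_count / byte_count)
+ 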
+ +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def ttt_adapt(args: Hyperparameters, base_model: nn.Module, device: torch.device, + val_tokens: Tensor, rank: int = 0, world_size: int = 1, + log_fn=None) -> None: + """Score-first TTT: process val data in chunks, score each chunk first + (inference_mode), then train on scored tokens. Compliant with Issue #677.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + batch_seqs = args.ttt_batch_seqs + + # Freeze early blocks + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + chunk_idx = 0 + + for chunk_start in range(0, total_tokens - seq_len, chunk_tokens): + chunk_end = min(chunk_start + chunk_tokens, total_tokens) + chunk_len = chunk_end - chunk_start + n_seqs = chunk_len // seq_len + if n_seqs == 0: + break + + my_start = (n_seqs * rank) // world_size + my_end = (n_seqs * (rank + 1)) // world_size + if my_end <= my_start: + continue + + # Phase 1: Score chunk under inference_mode (forward only) + base_model.eval() + with torch.inference_mode(): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + base_model.forward_logits(x) + + # Phase 2: Train on scored tokens (K epochs) + base_model.train() + for epoch in range(args.ttt_epochs): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + chunk_idx += 1 + if log_fn and chunk_idx % 20 == 0: + log_fn(f"ttt:chunk={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + # Restore all params + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done chunks={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, qmax: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + qmin = -qmax - 1 + pct = CastedLinear._quant_percentile + if t32.ndim == 2: + row_max = (torch.quantile(t32.abs(), pct, dim=1) if pct < 1.0 + else t32.abs().amax(dim=1)) + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)).to(torch.float16) + clipped = t32.clamp(-row_max[:, None], row_max[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), qmin, qmax).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / float(qmax) if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), qmin, qmax).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_layers: set[int] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + if int5_layers is None: + int5_layers = set() + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Determine layer index for int5 fallback + layer_idx = -1 + if name.startswith("blocks."): + try: + layer_idx = int(name.split(".")[1]) + except (IndexError, ValueError): + pass + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + qmax = 15 if layer_idx in int5_layers else 31 + q, s = quantize_int6_per_row(t, qmax=qmax) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5" if qmax == 15 else "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 + rank = int(os.environ.get("RANK", 
"0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + if _USE_FA3: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.flow_refiner is not None: + scalar_params.extend(list(base_model.flow_refiner.parameters())) + if base_model.native_flow is not None: + scalar_params.extend(list(base_model.native_flow.parameters())) + if base_model.ttt_refiner is not None: + scalar_params.extend(list(base_model.ttt_refiner.parameters())) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + 
group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:fa3={_USE_FA3} cudnn=False flash=True mem_efficient={not _USE_FA3} math={not _USE_FA3}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + eval_only_path = os.environ.get("EVAL_ONLY", "") + if eval_only_path: + log0(f"eval_only: loading {eval_only_path}, skipping training") + base_model.load_state_dict(torch.load(eval_only_path, map_location=device, weights_only=False), strict=False) + ema_state = None # prevent random EMA from overwriting loaded weights + swa_state = None + swa_count = 0 + args.iterations = 0 # skip training, go straight to eval + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._soft_round = args.soft_round_qat + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round:{args.soft_round_qat}") + if CastedLinear._qat_enabled and CastedLinear._soft_round: + qat_progress 
= max(0.0, 1.0 - (scale / qat_threshold)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress # 1→16 + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: + # Tight SWA: collect from EMA state if available, else from raw model + src = ema_state if ema_state is not None else {name: t.detach().float() for name, t in base_model.state_dict().items()} + if swa_state is None: + swa_state = {name: t.clone() for name, t in src.items()} + swa_count = 1 + log0(f"swa:start step:{step} tight={ema_state is not None}") + else: + for name in swa_state: + swa_state[name].add_(src[name]) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying Tight SWA averaged {swa_count} EMA checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + if ema_state is not None: + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + # Use RUN_ID for unique filenames when multiple jobs share CWD + _run_id = os.environ.get("RUN_ID", "") + _save_prefix = f"final_model_{_run_id}" if _run_id else "final_model" + _pt_path = f"{_save_prefix}.pt" + _ptz_path = f"{_save_prefix}.int6.ptz" + log0(f"save_paths: pt={_pt_path} ptz={_ptz_path}") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, _pt_path) + model_bytes = os.path.getsize(_pt_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + code_bytes = len(code.encode("utf-8")) + artifact_limit = 16_000_000 - code_bytes + + # --- Auto-downgrade quantization: try int6 first, fall back to int5 middle layers --- + num_layers_total = max( + (int(k.split(".")[1]) for k in sd_cpu if k.startswith("blocks.")), + default=0, + ) + 1 + _zstd_levels = [int(os.environ.get("ZSTD_LEVEL", "16")), 1, 17, 2] + # Phase 1: pure int6 with multiple zstd levels + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = None + chosen_level = _zstd_levels[0] + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int6 zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + # Phase 2: progressive int5 fallback — one layer at a time from middle outward + if quant_blob is None: + mid = num_layers_total // 2 + # Expand outward from center: L5, L4, L6, L3, L7, L2, L8, ... 
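+ # (Realized order: the loop below tries mid + offset before mid - offset, so with
+ # 11 layers and mid = 5 the sequence is L5, L6, L4, L7, L3, ..., which matches the
+ # quant_fallback lines later in this log: layers=[5], then layers=[5, 6].)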
+ candidates = [] + for offset in range(num_layers_total): + for sign in [0, 1]: + layer = mid + offset if sign == 0 else mid - offset + if 0 <= layer < num_layers_total and layer not in candidates: + candidates.append(layer) + int5_layers: set[int] = set() + for layer in candidates: + int5_layers.add(layer) + if master_process: + log0(f"quant_fallback: int5 layers={sorted(int5_layers)}") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, int5_layers=int5_layers) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int5[{len(int5_layers)}L] zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + if quant_blob is not None: + break + if quant_blob is None: + quant_blob = blob # Use last attempt even if over limit + if master_process: + log0(f"WARNING: artifact still over limit after all fallbacks") + if master_process: + with open(_ptz_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + log0(f"Serialized model quant+{_COMPRESSOR}-{chosen_level}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open(_ptz_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + if args.ttt_enabled: + if distributed: + dist.barrier() + for block in eval_model.blocks: + 
block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + log0(f"ttt:start score-first optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"chunk_tokens={args.ttt_chunk_tokens}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, + rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + + +==================================================================================================== +Running Python 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] +Running PyTorch 2.10.0+cu128 +Sat Mar 28 11:04:46 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.172.08 Driver Version: 570.172.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA A100-PCIE-40GB On | 00000000:CA:00.0 Off | 0 | +| N/A 33C P0 48W / 250W | 423MiB / 40960MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 2989968 C ...ameter_golf/.venv/bin/python3 414MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27530952 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:1 grad_accum_steps:8 +sdp_backends:fa3=False cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/7000 val_loss:6.9319 val_bpb:4.1055 train_time:0ms step_avg:0.01ms +step:1/7000 train_loss:7.1329 train_time:1910ms step_avg:1910.05ms +step:2/7000 train_loss:9.1205 train_time:3872ms step_avg:1935.75ms +step:3/7000 train_loss:8.0686 train_time:5836ms step_avg:1945.21ms +step:4/7000 train_loss:7.3950 train_time:7800ms step_avg:1949.95ms +step:5/7000 train_loss:7.1158 train_time:9768ms step_avg:1953.57ms +step:6/7000 train_loss:7.1063 train_time:11738ms step_avg:1956.33ms +step:7/7000 train_loss:6.9423 train_time:13711ms step_avg:1958.76ms +step:8/7000 train_loss:6.7316 train_time:15684ms step_avg:1960.48ms +step:9/7000 train_loss:6.4037 train_time:17659ms step_avg:1962.13ms +step:10/7000 train_loss:6.1082 train_time:19636ms step_avg:1963.62ms +step:200/7000 train_loss:2.7242 train_time:396315ms step_avg:1981.58ms +step:400/7000 train_loss:2.5459 train_time:791709ms step_avg:1979.27ms +step:500/7000 val_loss:2.3323 val_bpb:1.3813 train_time:989535ms step_avg:1979.07ms +step:600/7000 train_loss:2.3883 train_time:1187390ms step_avg:1978.98ms +step:800/7000 train_loss:2.3764 train_time:1583181ms step_avg:1978.98ms +step:1000/7000 train_loss:2.3550 train_time:1979099ms step_avg:1979.10ms +step:1000/7000 val_loss:2.2047 val_bpb:1.3058 train_time:1979102ms step_avg:1979.10ms +step:1200/7000 train_loss:2.3096 train_time:2375020ms step_avg:1979.18ms +step:1400/7000 train_loss:2.3519 train_time:2771203ms step_avg:1979.43ms +step:1500/7000 val_loss:2.1639 val_bpb:1.2816 train_time:2969412ms step_avg:1979.61ms +step:1600/7000 
train_loss:2.1972 train_time:3167603ms step_avg:1979.75ms +step:1800/7000 train_loss:2.2434 train_time:3564368ms step_avg:1980.20ms +step:2000/7000 train_loss:2.1350 train_time:3961132ms step_avg:1980.57ms +step:2000/7000 val_loss:2.1104 val_bpb:1.2499 train_time:3961136ms step_avg:1980.57ms +step:2200/7000 train_loss:2.2132 train_time:4357740ms step_avg:1980.79ms +step:2400/7000 train_loss:2.1653 train_time:4754499ms step_avg:1981.04ms +step:2500/7000 val_loss:2.0861 val_bpb:1.2355 train_time:4952992ms step_avg:1981.20ms +step:2600/7000 train_loss:2.1657 train_time:5151477ms step_avg:1981.34ms +step:2800/7000 train_loss:2.2096 train_time:5548340ms step_avg:1981.55ms +step:3000/7000 train_loss:2.1612 train_time:5945254ms step_avg:1981.75ms +step:3000/7000 val_loss:2.0739 val_bpb:1.2283 train_time:5945269ms step_avg:1981.76ms +step:3200/7000 train_loss:2.1662 train_time:6342108ms step_avg:1981.91ms +step:3400/7000 train_loss:2.1383 train_time:6738727ms step_avg:1981.98ms +step:3500/7000 val_loss:2.0657 val_bpb:1.2234 train_time:6937035ms step_avg:1982.01ms +step:3600/7000 train_loss:2.1801 train_time:7135459ms step_avg:1982.07ms +step:3800/7000 train_loss:2.1531 train_time:7532083ms step_avg:1982.13ms +step:4000/7000 train_loss:2.2310 train_time:7928770ms step_avg:1982.19ms +step:4000/7000 val_loss:2.0597 val_bpb:1.2199 train_time:7928773ms step_avg:1982.19ms +step:4200/7000 train_loss:2.1213 train_time:8325600ms step_avg:1982.29ms +step:4400/7000 train_loss:2.1049 train_time:8722496ms step_avg:1982.39ms +step:4500/7000 val_loss:2.0440 val_bpb:1.2106 train_time:8920832ms step_avg:1982.41ms +step:4600/7000 train_loss:2.0877 train_time:9119160ms step_avg:1982.43ms +step:4800/7000 train_loss:2.2189 train_time:9515876ms step_avg:1982.47ms +step:5000/7000 train_loss:2.1324 train_time:9912614ms step_avg:1982.52ms +step:5000/7000 val_loss:2.0220 val_bpb:1.1975 train_time:9912617ms step_avg:1982.52ms +step:5200/7000 train_loss:2.1037 train_time:10309364ms step_avg:1982.57ms +step:5400/7000 train_loss:2.0787 train_time:10705965ms step_avg:1982.59ms +step:5500/7000 val_loss:1.9997 val_bpb:1.1843 train_time:10904304ms step_avg:1982.60ms +step:5600/7000 train_loss:2.0623 train_time:11102510ms step_avg:1982.59ms +step:5800/7000 train_loss:2.0481 train_time:11499073ms step_avg:1982.60ms +step:6000/7000 train_loss:2.0294 train_time:11895613ms step_avg:1982.60ms +step:6000/7000 val_loss:1.9767 val_bpb:1.1707 train_time:11895616ms step_avg:1982.60ms +step:6200/7000 train_loss:2.1229 train_time:12292332ms step_avg:1982.63ms +step:6400/7000 train_loss:2.1040 train_time:12689010ms step_avg:1982.66ms +step:6500/7000 val_loss:1.9464 val_bpb:1.1527 train_time:12887389ms step_avg:1982.68ms +step:6600/7000 train_loss:2.0020 train_time:13085769ms step_avg:1982.69ms +step:6800/7000 train_loss:2.0838 train_time:13482720ms step_avg:1982.75ms +step:7000/7000 train_loss:1.9271 train_time:13879411ms step_avg:1982.77ms +step:7000/7000 val_loss:1.9215 val_bpb:1.1380 train_time:13879414ms step_avg:1982.77ms +peak memory allocated: 25832 MiB reserved: 26006 MiB +ema:applying EMA weights +save_paths: pt=final_model_pr940_nflow_55342820.pt ptz=final_model_pr940_nflow_55342820.int6.ptz +Serialized model: 107292674 bytes +Code size: 104738 bytes +quant_try int6 zstd-16: 16242245 bytes (limit 15895262) +quant_try int6 zstd-1: 16298729 bytes (limit 15895262) +quant_try int6 zstd-17: 16250339 bytes (limit 15895262) +quant_try int6 zstd-2: 16301974 bytes (limit 15895262) +quant_fallback: int5 layers=[5] +quant_try int5[1L] zstd-16: 
15927696 bytes (limit 15895262) +quant_try int5[1L] zstd-1: 16001920 bytes (limit 15895262) +quant_try int5[1L] zstd-17: 16297960 bytes (limit 15895262) +quant_try int5[1L] zstd-2: 16053283 bytes (limit 15895262) +quant_fallback: int5 layers=[5, 6] +quant_try int5[2L] zstd-16: 15630744 bytes (limit 15895262) +Serialized model quant+zstd-16: 15630744 bytes +Total submission size: 15735482 bytes +final_int6_roundtrip val_loss:1.9363 val_bpb:1.1468 eval_time:67775ms +final_int6_roundtrip_exact val_loss:1.93630746 val_bpb:1.14679034 +final_int6_sliding_window val_loss:1.8963 val_bpb:1.1231 stride:64 eval_time:1713117ms +final_int6_sliding_window_exact val_loss:1.89632895 val_bpb:1.12311579 diff --git a/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/train_gpt.py b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/train_gpt.py new file mode 100644 index 0000000000..e6e60b09c3 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_11L_NativeFlowMatcher_LegalTTT/train_gpt.py @@ -0,0 +1,2601 @@ +""" +train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + +fp16 embed + late-K passthrough + sliding window eval. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _USE_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _USE_FA3 = True + except ImportError: + _USE_FA3 = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + + # Training length. 
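+ # (Defaults below give the 20,000-step reference schedule; the runs in this record
+ # override ITERATIONS=7000 and MAX_WALLCLOCK_SECONDS=0 via environment variables,
+ # as shown in the training log above.)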
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0"))) + swa_every = int(os.environ.get("SWA_EVERY", 200)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) + soft_round_qat = bool(int(os.environ.get("SOFT_ROUND_QAT", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "1"))) + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "1"))) + canon_last_n = int(os.environ.get("CANON_LAST_N", 0)) + canon_kernel = int(os.environ.get("CANON_KERNEL", 4)) + canon_delta_gate_init = float(os.environ.get("CANON_DELTA_GATE_INIT", -4.0)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # TTT (Test-Time Training) — score-first, backward-looking + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # "sgd" or "adamw" + ttt_lr = float(os.environ.get("TTT_LR", 
0.0001)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 4)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 131072)) + + # FlowRefiner — additive latent-space refinement + flow_enabled = bool(int(os.environ.get("FLOW_ENABLED", "0"))) + flow_latent_dim = int(os.environ.get("FLOW_LATENT_DIM", "64")) + flow_hidden_dim = int(os.environ.get("FLOW_HIDDEN_DIM", "256")) + flow_init_scale = float(os.environ.get("FLOW_INIT_SCALE", "0.01")) + + # NativeFlowMatcher — time-conditioned CFM on hidden states + native_flow_enabled = bool(int(os.environ.get("NATIVE_FLOW_ENABLED", "0"))) + native_flow_hidden_dim = int(os.environ.get("NATIVE_FLOW_HIDDEN_DIM", "256")) + native_flow_init_scale = float(os.environ.get("NATIVE_FLOW_INIT_SCALE", "0.01")) + native_flow_loss_weight = float(os.environ.get("NATIVE_FLOW_LOSS_WEIGHT", "0.1")) + + # E2E TTT-Linear refiner (Sun et al., 2024) + e2e_ttt_enabled = bool(int(os.environ.get("E2E_TTT_ENABLED", "0"))) + e2e_ttt_num_heads = int(os.environ.get("E2E_TTT_NUM_HEADS", "8")) + e2e_ttt_mini_batch = int(os.environ.get("E2E_TTT_MINI_BATCH", "16")) + e2e_ttt_base_lr = float(os.environ.get("E2E_TTT_BASE_LR", "1.0")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + 
p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
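+
+# Illustrative only (not used by the export path below): a minimal per-row symmetric
+# int8 round-trip for a 2D weight matrix, showing the quantize -> compress ->
+# decompress -> dequantize flow described above. The helper name and the zlib choice
+# are assumptions for this sketch; the real exporter prefers zstd when available.
+def _int8_roundtrip_sketch(w: Tensor) -> Tensor:
+    w = w.detach().float().cpu()                              # sketch operates on a CPU copy
+    scale = (w.abs().amax(dim=1) / 127.0).clamp_min(1e-8)     # one scale per weight row
+    q = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)
+    payload = zlib.compress(q.numpy().tobytes())              # these bytes count toward the 16MB budget
+    q_back = torch.frombuffer(bytearray(zlib.decompress(payload)), dtype=torch.int8).reshape(w.shape)
+    return q_back.float() * scale[:, None]                    # dequantized weights used at eval time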
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,vr_lambda,attn_gate,canon_a,canon_c,delta_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, qmax: int = 127) -> tuple[Tensor, Tensor]: + """Quantize to [-qmax, qmax] range. Default int8 (qmax=127), int6 (qmax=31), int5 (qmax=15).""" + t32 = t.float() + qmin = -qmax + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        # Mixed quantization: int6 for the MLP layers listed in INT6_MLP_LAYERS to save artifact space
+        int6_mlp_layers = os.environ.get("INT6_MLP_LAYERS", "")
+        qmax = 127  # default int8
+        if int6_mlp_layers:
+            for li in int6_mlp_layers.split(","):
+                if li.strip() and f"blocks.{li.strip()}.mlp" in name and t.ndim == 2:
+                    qmax = 31  # int6
+                    break
+        q, s = quantize_float_tensor(t, qmax=qmax)
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout assumed here: a 256-entry little-endian int32 header (token count at
+    # index 2) followed by uint16 token ids, as written by the preprocessing pipeline.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    # Token ids (vocab 1024) fit comfortably in int16.
+    return torch.from_numpy(tokens.astype(np.int16))
+
+
+class TokenStream:
+    # Sequential reader over the train shards; take(n) returns the next n tokens,
+    # crossing shard boundaries (and wrapping around) as needed.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
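+    # Worked example with hypothetical numbers: with world_size=2, grad_accum_steps=1,
+    # global_tokens=8 and seq_len=4, each rank needs 8 // 2 = 4 tokens plus one extra,
+    # so the stream yields a chunk of 10 tokens; rank 0 takes chunk[0:5], rank 1 takes
+    # chunk[5:10], and each rank shifts its span by one position to form (x, y).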
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round: bool = False + _soft_round_alpha: float = 1.0 + _quant_percentile: float = float(os.environ.get("QUANT_PERCENTILE", "1.0")) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + w32 = self.weight.float() + pct = CastedLinear._quant_percentile + row_max = (torch.quantile(w32.abs(), pct, dim=1) if pct < 1.0 + else w32.abs().amax(dim=1)).detach() + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + r = w32 / scale[:, None] + if CastedLinear._soft_round: + alpha = CastedLinear._soft_round_alpha + r_frac = r - r.detach().floor() - 0.5 + norm = torch.tanh(torch.tensor(alpha * 0.5, device=r.device, dtype=r.dtype)) + r_soft = r.detach().floor() + 0.5 + torch.tanh(alpha * r_frac) / (2.0 * norm) + w_q = (torch.clamp(r_soft, -32, 31) * scale[:, None]).to(x.dtype) + w = w_q # soft-round is differentiable, no STE needed + else: + with torch.no_grad(): + w_q = (torch.clamp(torch.round(r), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. 
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + value_residual: bool = False, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = 
v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + fa_dtype = torch.bfloat16 + if _USE_FA3: + y = flash_attn_3_func(q.to(fa_dtype), k.to(fa_dtype), v.to(fa_dtype), causal=True) + else: + # SDPA fallback: (B, T, H, D) -> (B, H, T, D), expand KV for GQA + q_t = q.to(fa_dtype).transpose(1, 2) + k_t = k.to(fa_dtype).transpose(1, 2) + v_t = v.to(fa_dtype).transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(rep, dim=1) + v_t = v_t.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True) + y = y.transpose(1, 2) # (B, H, T, D) -> (B, T, H, D) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)) # (B, T, num_heads) + y = y * gate.unsqueeze(-1) # (B, T, H, 1) broadcast to (B, T, H, D) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.use_leaky = bool(int(os.environ.get("LEAKY_RELU", "1"))) + self.leaky_slope = 
float(os.environ.get("LEAKY_SLOPE", "0.5")) + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), self.leaky_slope) if self.use_leaky else torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class CanonAC(nn.Module): + """Canon Autoregressive Convolution with DeltaGate. Manual shift+mul (no Conv1d).""" + def __init__(self, dim: int, kernel: int = 4, delta_gate_init: float = -4.0): + super().__init__() + self.kernel = kernel + self.weight = nn.Parameter(torch.zeros(kernel, dim)) + self.delta_gate_logit = nn.Parameter(torch.tensor(delta_gate_init)) + + def forward(self, x: Tensor) -> Tensor: + B, T, D = x.shape + K = self.kernel + w = self.weight.to(x.dtype) + x_pad = F.pad(x, (0, 0, K - 1, 0)) + y = w[0] * x_pad[:, K - 1:, :] + for k in range(1, K): + y = y + w[k] * x_pad[:, K - 1 - k : T + K - 1 - k, :] + gate = torch.sigmoid(self.delta_gate_logit.to(x.dtype)) + return x + gate * y + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_kernel: int = 0, + canon_delta_gate_init: float = -4.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, value_residual=value_residual, + gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.canon_a = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + self.canon_c = CanonAC(dim, canon_kernel, canon_delta_gate_init) if canon_kernel > 0 else None + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_in = self.attn_norm(x) * s + if self.canon_a is not None: + attn_in = self.canon_a(attn_in) + attn_out, raw_v = self.attn(attn_in, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s + if self.canon_c is not None: + mlp_in = self.canon_c(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + return x, raw_v + + +class FlowRefiner(nn.Module): + """1-step flow matching refiner in low-dim latent space. + + Projects hidden states to a low-dim latent, applies a learned velocity + field (1-step Euler), projects back. Output is additive adjustment to + hidden states. Initialized near-zero so the refiner starts as identity. 
+ """ + def __init__(self, model_dim: int, latent_dim: int = 64, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.down_proj = nn.Linear(model_dim, latent_dim, bias=False) + self.velocity_net = nn.Sequential( + nn.Linear(latent_dim, hidden_dim, bias=True), + nn.GELU(), + nn.Linear(hidden_dim, latent_dim, bias=True), + ) + self.up_proj = nn.Linear(latent_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.velocity_net[2]._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.down_proj.weight, std=self._init_scale) + nn.init.normal_(self.velocity_net[0].weight, std=self._init_scale) + nn.init.zeros_(self.velocity_net[0].bias) + nn.init.zeros_(self.velocity_net[2].weight) + nn.init.zeros_(self.velocity_net[2].bias) + nn.init.normal_(self.up_proj.weight, std=self._init_scale) + + def forward(self, x: Tensor) -> Tensor: + z = self.down_proj(x) + v = self.velocity_net(z) + z_refined = z + v + delta = self.up_proj(z_refined) + return torch.sigmoid(self.gate) * delta + + +class NativeFlowMatcher(nn.Module): + """Conditional Flow Matching refiner for hidden states. + + During training: computes auxiliary CFM loss on interpolated states. + During inference: applies a single Euler step correction at t=1. + """ + def __init__(self, model_dim: int, hidden_dim: int = 256, init_scale: float = 0.01): + super().__init__() + self.model_dim = model_dim + # Sinusoidal time embedding → projected to hidden_dim + self.time_proj = nn.Sequential( + nn.Linear(model_dim, hidden_dim, bias=True), + nn.GELU(), + ) + # Velocity network: x → hidden → x (with time conditioning via addition) + self.v_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.v_act = nn.GELU() + self.v_out = nn.Linear(hidden_dim, model_dim, bias=False) + self.gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5) ≈ 0.007 + self._init_scale = init_scale + self._init_weights() + self.v_out._zero_init = True + + def _init_weights(self): + nn.init.normal_(self.v_in.weight, std=self._init_scale) + nn.init.zeros_(self.v_in.bias) + nn.init.normal_(self.time_proj[0].weight, std=self._init_scale) + nn.init.zeros_(self.time_proj[0].bias) + nn.init.zeros_(self.v_out.weight) + + @staticmethod + def _sinusoidal_time_emb(t: Tensor, dim: int) -> Tensor: + """Sinusoidal positional embedding for scalar time t. t: (...,) -> (..., dim).""" + half_dim = dim // 2 + emb = math.log(10000.0) / max(half_dim - 1, 1) + emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb) + emb = t.unsqueeze(-1) * emb # (..., half_dim) + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) # (..., dim) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + def _velocity(self, x: Tensor, t_emb: Tensor) -> Tensor: + """Compute velocity v(x, t). x: (..., model_dim), t_emb: (..., hidden_dim).""" + h = self.v_in(x) # (..., hidden_dim) + h = h + t_emb # time conditioning via addition + h = self.v_act(h) # GELU + return self.v_out(h) # (..., model_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (correction, cfm_loss). + + During training: cfm_loss is the flow matching objective (> 0). + During eval: cfm_loss is zero. + correction: gated velocity at t=1, to be added to x. 
+ """ + B_seq = x.shape[:-1] # works for any leading dimensions + D = self.model_dim + + # Inference path: velocity at t=1 (clean input, no noise) + t_ones = torch.ones(*B_seq, device=x.device, dtype=x.dtype) + t_emb_ones = self.time_proj(self._sinusoidal_time_emb(t_ones, D)) + v_at_1 = self._velocity(x, t_emb_ones) + correction = torch.sigmoid(self.gate) * v_at_1 + + # Training path: auxiliary CFM loss + if self.training: + # Sample random t ~ U(0,1) per position + t = torch.rand(*B_seq, device=x.device, dtype=x.dtype) + z = torch.randn_like(x) # noise + # OT interpolant: x_t = (1-t)*z + t*x + t_expanded = t.unsqueeze(-1) # (..., 1) + x_t = (1.0 - t_expanded) * z + t_expanded * x.detach() + # Target velocity = x - z (OT conditional velocity field) + target_v = x.detach() - z + # Predict velocity at interpolated point + t_emb = self.time_proj(self._sinusoidal_time_emb(t, D)) + pred_v = self._velocity(x_t, t_emb) + cfm_loss = F.mse_loss(pred_v, target_v) + else: + cfm_loss = x.new_zeros(()) + + return correction, cfm_loss + + +class TTTLinearRefiner(nn.Module): + """E2E TTT-Linear refiner (Sun et al., 2024, arXiv:2407.04620). + + Applied to final hidden states [B, L, D] before lm_head. The hidden + state is a per-head linear model W updated via mini-batch gradient + descent on a learned self-supervised reconstruction task. Uses the + dual form for GPU efficiency. Trained end-to-end: the outer loop + optimizes theta_K, theta_V, theta_Q projections and the inner-loop + initialization W_0, making the model learn WHAT to compress. + """ + + def __init__(self, model_dim: int, num_heads: int = 8, + mini_batch_size: int = 16, ttt_base_lr: float = 1.0): + super().__init__() + self.model_dim = model_dim + self.num_heads = num_heads + self.head_dim = model_dim // num_heads + self.mini_batch_size = mini_batch_size + self.ttt_base_lr = ttt_base_lr + self.eta = ttt_base_lr / self.head_dim + + # Learnable projections (outer-loop) + self.theta_K = nn.Linear(model_dim, model_dim, bias=False) + self.theta_V = nn.Linear(model_dim, model_dim, bias=False) + self.theta_Q = nn.Linear(model_dim, model_dim, bias=False) + + # Inner-loop model: W_0, b_0 + self.W1 = nn.Parameter(torch.normal(0, 0.02, + size=(num_heads, self.head_dim, self.head_dim))) + self.b1 = nn.Parameter(torch.zeros(num_heads, 1, self.head_dim)) + + # Inner-loop LayerNorm (outer-loop trainable) + self.ttt_ln_w = nn.Parameter(torch.ones(num_heads, self.head_dim)) + self.ttt_ln_b = nn.Parameter(torch.zeros(num_heads, self.head_dim)) + + # Output projection + gate (starts near zero for residual safety) + self.o_proj = nn.Linear(model_dim, model_dim, bias=False) + self.post_norm = nn.LayerNorm(model_dim, eps=1e-6) + self.gate = nn.Parameter(torch.tensor(-5.0)) + nn.init.zeros_(self.o_proj.weight) + self.o_proj._zero_init = True + + @staticmethod + def _ln_fwd(x: Tensor, gamma: Tensor, beta: Tensor, eps: float = 1e-6) -> Tensor: + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x_hat = (x - mu) / torch.sqrt(var + eps) + return gamma * x_hat + beta + + @staticmethod + def _ln_fused_l2_bwd(x: Tensor, target: Tensor, gamma: Tensor, beta: Tensor, + eps: float = 1e-6) -> Tensor: + D = x.shape[-1] + mu = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + std = torch.sqrt(var + eps) + x_hat = (x - mu) / std + y = gamma * x_hat + beta + grad_output = y - target + grad_x_hat = grad_output * gamma + z = (1.0 / D) * ( + D * grad_x_hat + - grad_x_hat.sum(dim=-1, keepdim=True) + - x_hat * 
(grad_x_hat * x_hat).sum(dim=-1, keepdim=True) + ) / std + return z + + def forward(self, x: Tensor) -> Tensor: + """x: [B, L, D]. Returns [B, L, D] additive refinement (gated).""" + B, L, D = x.shape + NH, HD, b = self.num_heads, self.head_dim, self.mini_batch_size + eta = self.eta + + num_mb = L // b + usable_L = num_mb * b + x_proc = x[:, :usable_L] if usable_L < L else x + + # Project: [B, L, D] -> [B, NH, L, HD] + XK = self.theta_K(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XV = self.theta_V(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + XQ = self.theta_Q(x_proc).reshape(B, usable_L, NH, HD).permute(0, 2, 1, 3) + + # Reshape into mini-batches: [B, NH, num_mb, b, HD] + XK = XK.reshape(B, NH, num_mb, b, HD) + XV = XV.reshape(B, NH, num_mb, b, HD) + XQ = XQ.reshape(B, NH, num_mb, b, HD) + + ln_w = self.ttt_ln_w.reshape(1, NH, 1, HD) + ln_b = self.ttt_ln_b.reshape(1, NH, 1, HD) + + # Causal mask for dual form + causal = torch.tril(torch.ones(b, b, device=x.device, dtype=x.dtype)) + + # Initialize inner-loop model + W = self.W1.unsqueeze(0).expand(B, -1, -1, -1).clone() + bias = self.b1.unsqueeze(0).expand(B, -1, -1, -1).clone() + + all_out = [] + for m in range(num_mb): + xk = XK[:, :, m] + xv = XV[:, :, m] + xq = XQ[:, :, m] + + Z1 = xk @ W + bias + target = xv - xk + + grad = self._ln_fused_l2_bwd(Z1, target, ln_w, ln_b) + + Attn = causal.unsqueeze(0).unsqueeze(0) * (xq @ xk.transpose(-2, -1)) + cumgrad = (causal.unsqueeze(0).unsqueeze(0) @ grad) + b_bar = bias - eta * cumgrad + Z_bar = xq @ W - eta * Attn @ grad + b_bar + + W = W - eta * xk.transpose(-2, -1) @ grad + bias = bias - eta * grad.sum(dim=-2, keepdim=True) + + Z_bar = self._ln_fwd(Z_bar, ln_w, ln_b) + out_mb = xq + Z_bar + all_out.append(out_mb) + + z = torch.cat(all_out, dim=2).permute(0, 2, 1, 3).reshape(B, usable_L, D) + z = self.post_norm(z) + z = self.o_proj(z) + result = torch.sigmoid(self.gate) * z + + if usable_L < L: + pad = torch.zeros(B, L - usable_L, D, device=x.device, dtype=x.dtype) + result = torch.cat([result, pad], dim=1) + + return result + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + value_residual: bool = False, + gated_attention: bool = False, + canon_last_n: int = 0, + canon_kernel: int = 4, + canon_delta_gate_init: float = -4.0, + flow_enabled: bool = False, + flow_latent_dim: int = 64, + flow_hidden_dim: int = 256, + flow_init_scale: float = 0.01, + native_flow_enabled: bool = False, + native_flow_hidden_dim: int = 256, + native_flow_init_scale: float = 0.01, + native_flow_loss_weight: float = 0.1, + e2e_ttt_enabled: bool = False, + e2e_ttt_num_heads: int = 8, + e2e_ttt_mini_batch: int = 16, + e2e_ttt_base_lr: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, 
bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + canon_start = num_layers - canon_last_n if canon_last_n > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + value_residual=value_residual, + gated_attention=gated_attention, + canon_kernel=canon_kernel if i >= canon_start else 0, + canon_delta_gate_init=canon_delta_gate_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.flow_refiner = FlowRefiner(model_dim, flow_latent_dim, flow_hidden_dim, flow_init_scale) if flow_enabled else None + self.native_flow = NativeFlowMatcher(model_dim, native_flow_hidden_dim, native_flow_init_scale) if native_flow_enabled else None + self.native_flow_loss_weight = native_flow_loss_weight + self.ttt_refiner = TTTLinearRefiner(model_dim, e2e_ttt_num_heads, e2e_ttt_mini_batch, e2e_ttt_base_lr) if e2e_ttt_enabled else None + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + nflow_loss = x.new_zeros(()) + if self.native_flow is not None: + correction, nflow_loss = self.native_flow(x) + x = x + correction + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + if self.native_flow is not None and self.training: + main_loss = main_loss + self.native_flow_loss_weight * nflow_loss + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, raw_v = self.blocks[i](x, x0, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + x = self.final_norm(x) + if self.ttt_refiner is not None: + x = x + self.ttt_refiner(x) + if self.flow_refiner is not None: + x = x + self.flow_refiner(x) + if self.native_flow is not None: + correction, _ = self.native_flow(x) + x = x + correction + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + 
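+# Sketch of the scoring schedule used by eval_val_sliding below (the helper name is
+# illustrative, not part of the evaluator): windows start every `stride` tokens, the
+# first window scores all of its targets, and each later window scores only its final
+# `stride` targets, so in the interior of the stream every target is scored exactly
+# once, by the window that gives it the longest available left context.
+def _scored_span_sketch(ws: int, seq_len: int, stride: int, total_tokens: int) -> tuple[int, int]:
+    wlen = min(ws + seq_len, total_tokens) - ws    # targets visible to this window
+    s = 0 if ws == 0 else max(wlen - stride, 0)    # first scored offset inside the window
+    return ws + s, ws + wlen                       # half-open span of scored target offsets
+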
+def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + Optionally uses entropy-gated 5-gram cache (NGRAM_CACHE=1).""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # N-gram eval cache with multi-order backoff + entropy-adaptive alpha (PR #702 inspired) + _ngram_default = "1" if world_size > 1 else "0" + use_ngram = bool(int(os.environ.get("NGRAM_CACHE", _ngram_default))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.40")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_order = int(os.environ.get("NGRAM_ORDER", "7")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = ngram_order - ngram_min_order + 1 + ctx_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + print(f"ngram_cache:enabled orders={ngram_min_order}-{ngram_order} backoff " + f"entropy={ngram_entropy} alpha={ngram_alpha} " + f"ent_base={ngram_ent_base} ent_range={ngram_ent_range} " + f"min_count={ngram_min_count} buckets={ngram_buckets}", flush=True) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + scored_nll = nll[i, s:wlen].to(torch.float64) + + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha: compute from model logits (GPU) + if ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ngram_ent_base + ngram_ent_range / ( + 1.0 + np.exp(-ngram_ent_scale * (seg_ent - ngram_ent_thresh))) + + # Precompute hashes for all orders + order_data = [] # (v_idx, ctx_key, full_key) per order + for oi in range(_n_orders): + ctx_w = ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, fill unmatched with lower orders + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if ngram_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll_np = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + scored_nll = torch.from_numpy(seg_nll_np).to(dtype=torch.float64, device=device) + + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + 
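+# Worked example of the BPB conversion above, using the seed-42 sliding-window numbers
+# reported in the README: val_loss = 1.89632895 nats/token -> 1.89632895 / ln(2)
+# ≈ 2.7358 bits/token, and the logged val_bpb of 1.12311579 then implies roughly
+# 0.4105 tokens/byte on the validation text, since 2.7358 * 0.4105 ≈ 1.1231 bits/byte.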
+ +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _clear_rotary_caches(model: nn.Module) -> None: + """Clear cached RoPE tensors to avoid 'Inference tensors cannot be saved for backward'.""" + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + + +def ttt_adapt(args: Hyperparameters, base_model: nn.Module, device: torch.device, + val_tokens: Tensor, rank: int = 0, world_size: int = 1, + log_fn=None) -> None: + """Score-first TTT: process val data in chunks, score each chunk first + (inference_mode), then train on scored tokens. Compliant with Issue #677.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + batch_seqs = args.ttt_batch_seqs + + # Freeze early blocks + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + chunk_idx = 0 + + for chunk_start in range(0, total_tokens - seq_len, chunk_tokens): + chunk_end = min(chunk_start + chunk_tokens, total_tokens) + chunk_len = chunk_end - chunk_start + n_seqs = chunk_len // seq_len + if n_seqs == 0: + break + + my_start = (n_seqs * rank) // world_size + my_end = (n_seqs * (rank + 1)) // world_size + if my_end <= my_start: + continue + + # Phase 1: Score chunk under inference_mode (forward only) + base_model.eval() + with torch.inference_mode(): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + base_model.forward_logits(x) + + # Phase 2: Train on scored tokens (K epochs) + base_model.train() + for epoch in range(args.ttt_epochs): + for si in range(my_start, my_end, batch_seqs): + se = min(si + batch_seqs, my_end) + raw_s = chunk_start + si * seq_len + raw_e = chunk_start + se * seq_len + 1 + local = val_tokens[raw_s:raw_e].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + chunk_idx += 1 + if log_fn and chunk_idx % 20 == 0: + log_fn(f"ttt:chunk={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + # Restore all params + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done chunks={chunk_idx} elapsed={time.perf_counter()-t0:.1f}s") + + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, 
+ eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + """Legal single-pass TTT: score each chunk with sliding windows, then train on it. + Tokens are always scored BEFORE any training on their chunk, so the evaluation + is never contaminated by future information.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Build window starts (same logic as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Map each window to the chunk that contains its first scored token + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + if log_fn: + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + if log_fn: + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + _clear_rotary_caches(base_model) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR schedule with 5% warmup + warmup_chunks = max(num_chunks // 20, 1) + if ci < warmup_chunks: + lr_scale = (ci + 1) / warmup_chunks + else: + progress = (ci - warmup_chunks) / max(num_chunks - 1 - warmup_chunks, 1) + lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress)) + cos_lr = args.ttt_lr * lr_scale + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if log_fn and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all params and return to eval mode + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + if log_fn: + log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, qmax: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + qmin = -qmax - 1 + pct = CastedLinear._quant_percentile + if t32.ndim == 2: + row_max = (torch.quantile(t32.abs(), pct, dim=1) if pct < 1.0 + else t32.abs().amax(dim=1)) + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)).to(torch.float16) + clipped = t32.clamp(-row_max[:, None], row_max[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), qmin, qmax).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / float(qmax) if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), qmin, qmax).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_layers: set[int] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + if int5_layers is None: + int5_layers = set() + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Determine layer index for int5 fallback + layer_idx = -1 + if name.startswith("blocks."): + try: + layer_idx = int(name.split(".")[1]) + except (IndexError, ValueError): + pass + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + qmax = 15 if layer_idx in int5_layers else 31 + q, s = 
quantize_int6_per_row(t, qmax=qmax) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5" if qmax == 15 else "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + if _USE_FA3: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.flow_refiner is not None: + scalar_params.extend(list(base_model.flow_refiner.parameters())) + if base_model.native_flow is not None: + scalar_params.extend(list(base_model.native_flow.parameters())) + if base_model.ttt_refiner is not None: + scalar_params.extend(list(base_model.ttt_refiner.parameters())) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:fa3={_USE_FA3} cudnn=False flash=True mem_efficient={not _USE_FA3} math={not _USE_FA3}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 
1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + eval_only_path = os.environ.get("EVAL_ONLY", "") + if eval_only_path: + log0(f"eval_only: loading {eval_only_path}, skipping training") + base_model.load_state_dict(torch.load(eval_only_path, map_location=device, weights_only=False), strict=False) + ema_state = None # prevent random EMA from overwriting loaded weights + swa_state = None + swa_count = 0 + args.iterations = 0 # skip training, go straight to eval + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms 
step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if ( + args.checkpoint_every > 0 + and step > 0 + and step % args.checkpoint_every == 0 + and not last_step + and master_process + ): + ckpt_sd = {k: v for k, v in base_model.state_dict().items() if "mtp_heads" not in k} + ckpt_path = f"checkpoint_step{step}_{args.run_id}.pt" + torch.save(ckpt_sd, ckpt_path) + log0(f"checkpoint_saved: {ckpt_path} ({os.path.getsize(ckpt_path)} bytes)") + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._soft_round = args.soft_round_qat + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round:{args.soft_round_qat}") + if CastedLinear._qat_enabled and CastedLinear._soft_round: + qat_progress = max(0.0, 1.0 - (scale / qat_threshold)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress # 1→16 + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: + # Tight SWA: collect from EMA state if available, else from raw model + src = ema_state if ema_state is not None else {name: t.detach().float() for name, t in base_model.state_dict().items()} + if swa_state is None: + swa_state = {name: t.clone() for name, t in src.items()} + swa_count = 1 + log0(f"swa:start step:{step} tight={ema_state is not None}") + else: + for name in swa_state: + swa_state[name].add_(src[name]) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the 
wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying Tight SWA averaged {swa_count} EMA checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + if ema_state is not None: + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + # Use RUN_ID for unique filenames when multiple jobs share CWD + _run_id = os.environ.get("RUN_ID", "") + _save_prefix = f"final_model_{_run_id}" if _run_id else "final_model" + _pt_path = f"{_save_prefix}.pt" + _ptz_path = f"{_save_prefix}.int6.ptz" + log0(f"save_paths: pt={_pt_path} ptz={_ptz_path}") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, _pt_path) + model_bytes = os.path.getsize(_pt_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + code_bytes = len(code.encode("utf-8")) + artifact_limit = 16_000_000 - code_bytes + + # --- Auto-downgrade quantization: try int6 first, fall back to int5 middle layers --- + num_layers_total = max( + (int(k.split(".")[1]) for k in sd_cpu if k.startswith("blocks.")), + default=0, + ) + 1 + _zstd_levels = [int(os.environ.get("ZSTD_LEVEL", "16")), 1, 17, 2] + # Phase 1: pure int6 with multiple zstd levels + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = None + chosen_level = _zstd_levels[0] + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int6 zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + # Phase 2: progressive int5 fallback — one layer at a time from middle outward + if quant_blob is None: + mid = num_layers_total // 2 + # Expand outward from center: L5, L4, L6, L3, L7, L2, L8, ... 
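+        # NOTE: int5_layers accumulates across iterations of the loop below, so each attempt
+        # downgrades one additional block (middle-out) to int5 and re-tries every zstd level;
+        # the search stops at the first configuration whose compressed size fits the budget.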
+ candidates = [] + for offset in range(num_layers_total): + for sign in [0, 1]: + layer = mid + offset if sign == 0 else mid - offset + if 0 <= layer < num_layers_total and layer not in candidates: + candidates.append(layer) + int5_layers: set[int] = set() + for layer in candidates: + int5_layers.add(layer) + if master_process: + log0(f"quant_fallback: int5 layers={sorted(int5_layers)}") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, int5_layers=int5_layers) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + for lvl in _zstd_levels: + blob = zstandard.ZstdCompressor(level=lvl).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + log0(f"quant_try int5[{len(int5_layers)}L] zstd-{lvl}: {len(blob)} bytes (limit {artifact_limit})") + if len(blob) <= artifact_limit: + quant_blob = blob + chosen_level = lvl + break + if quant_blob is not None: + break + if quant_blob is None: + quant_blob = blob # Use last attempt even if over limit + if master_process: + log0(f"WARNING: artifact still over limit after all fallbacks") + if master_process: + with open(_ptz_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + log0(f"Serialized model quant+{_COMPRESSOR}-{chosen_level}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open(_ptz_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + value_residual=args.value_residual, + gated_attention=args.gated_attention, + canon_last_n=args.canon_last_n, + canon_kernel=args.canon_kernel, + canon_delta_gate_init=args.canon_delta_gate_init, + flow_enabled=args.flow_enabled, + flow_latent_dim=args.flow_latent_dim, + flow_hidden_dim=args.flow_hidden_dim, + flow_init_scale=args.flow_init_scale, + native_flow_enabled=args.native_flow_enabled, + native_flow_hidden_dim=args.native_flow_hidden_dim, + native_flow_init_scale=args.native_flow_init_scale, + native_flow_loss_weight=args.native_flow_loss_weight, + e2e_ttt_enabled=args.e2e_ttt_enabled, + e2e_ttt_num_heads=args.e2e_ttt_num_heads, + e2e_ttt_mini_batch=args.e2e_ttt_mini_batch, + e2e_ttt_base_lr=args.e2e_ttt_base_lr, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + legal_ttt = bool(int(os.environ.get("LEGAL_TTT", "0"))) + if args.ttt_enabled and not legal_ttt: + # --- 
Invalid two-pass TTT (adapt then eval separately) --- + if distributed: + dist.barrier() + for block in eval_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + log0(f"ttt:start score-first optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"chunk_tokens={args.ttt_chunk_tokens}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, + rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + if legal_ttt and args.ttt_enabled: + # Legal single-pass TTT: score → train interleaved per chunk + log0(f"legal_ttt:start stride={args.eval_stride} " + f"optimizer={args.ttt_optimizer} lr={args.ttt_lr} " + f"epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + sw_val_loss, sw_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + log_fn=log0, + ) + else: + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +
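+
+# -----------------------------
+# USAGE SKETCH (assumed invocation, not part of the measured recipe)
+# -----------------------------
+# The script takes no command-line arguments; behaviour is controlled by the
+# environment variables read above:
+#   Train, quantize, and evaluate from scratch:
+#       python train_gpt.py
+#   Re-evaluate a saved checkpoint with the legal single-pass TTT, skipping training:
+#       EVAL_ONLY=final_model.pt LEGAL_TTT=1 python train_gpt.py
+#   Multi-GPU runs launch via torchrun, which provides RANK / WORLD_SIZE / LOCAL_RANK.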