From 0c39618b275b87e2ee242a66e8f0afc8bed83da0 Mon Sep 17 00:00:00 2001 From: jamesEmerson112 <36806380+jamesEmerson112@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:38:58 -0700 Subject: [PATCH] Add submission: SP8192 HeadwiseGate LeakyReLU2 LegalTTT (val_bpb 1.2073) 3-seed mean 1.2073 BPB (std 0.0006) on 8xH100 SXM. SP8192 + headwise gated attention (original) + LeakyReLU(0.5)^2 + QK-Gain 5.0 + score-first TTT. MODEL_DIM=448, 16.4M params, ~15.34 MB artifact (under 16 MB budget). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../README.md | 138 ++ .../submission.json | 40 + .../train_gpt.py | 1529 +++++++++++++ .../train_seed1337.log | 1929 +++++++++++++++++ .../train_seed2025.log | 1929 +++++++++++++++++ .../train_seed42.log | 1929 +++++++++++++++++ 6 files changed, 7494 insertions(+) create mode 100644 records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/README.md create mode 100644 records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/submission.json create mode 100644 records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_gpt.py create mode 100644 records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed2025.log create mode 100644 records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed42.log diff --git a/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/README.md b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/README.md new file mode 100644 index 0000000000..7c50e50b4c --- /dev/null +++ b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/README.md @@ -0,0 +1,138 @@ +# SP8192 + Headwise Gated Attention + LeakyReLU2 + QK-Gain 5.0 + Legal TTT + +**val_bpb: 1.2073** (3-seed mean, std 0.0006) | **~15.34 MB** | 8xH100 SXM + +## Results (8xH100 80GB SXM, PyTorch 2.7.1) + +| Seed | step_avg | steps | val_bpb (TTT) | val_bpb (int8) | Artifact Size | +|------|----------|-------|---------------|----------------|---------------| +| 1337 | 54.40ms | 11,030 | 1.20665 | 1.20807 | 15,340,947 | +| 42 | 54.41ms | 11,028 | 1.20783 | 1.21029 | 15,340,685 | +| 2025 | 54.37ms | 11,036 | 1.20746 | 1.21016 | 15,337,072 | +| **Mean** | 54.39ms | 11,031 | **1.20731 (std 0.0006)** | 1.20951 | 15,339,568 | + +## Key Techniques + +1. **SP8192** -- 8192-token SentencePiece BPE vocabulary (@kevclark, PR #1394). All top 5 submissions use SP8192. Dataset from `kevclark/parameter-golf` HuggingFace repo. + +2. **Headwise Gated Attention** -- **Original technique by James Vo.** Sigmoid gate applied per-head after scaled dot-product attention. Each head learns a scalar gate that suppresses uninformative attention patterns. Adds only ~37K parameters (~0.2% overhead). Inspired by NeurIPS 2025 Best Paper (arxiv.org/abs/2505.06708). + + ```python + # In CausalSelfAttention.forward(): + gate_logits = self.gate_proj(x) # [bsz, seqlen, num_heads] + gate = torch.sigmoid(gate_logits) + gate = gate.unsqueeze(-1) # [bsz, seqlen, num_heads, 1] + y = y * gate # applied after SDPA, before output projection + ``` + +3. **LeakyReLU(0.5)^2** -- Replaces ReLU^2 in MLP. Preserves small negative gradients instead of zeroing them. Zero extra parameters, zero speed penalty (@abaybektursun, PR #549). + +4. **QK-Gain 5.0** -- Learnable per-head scalar that scales query vectors before attention. Initialized to 5.0 instead of the default 1.5, giving sharper attention patterns (@dexhunter, PR #1413). + +5. **Score-First TTT** -- Legal test-time training. For each 32K-token chunk of the validation set: (1) SCORE under `torch.inference_mode()`, accumulating loss/bytes for BPB, (2) TRAIN on the already-scored chunk using SGD (lr=0.005, momentum=0.9, 3 epochs, gradient clip 1.0). Last chunk is score-only. Model state is saved before TTT and restored after (@dexhunter, PR #1413). + +## Architecture + +| Component | Setting | +|-----------|---------| +| Layers | 9 | +| Model dim | 448 (reduced from 512 to fit 16 MB budget) | +| Attention heads | 8 | +| KV heads | 4 (GQA) | +| MLP mult | 2x | +| Activation | LeakyReLU(0.5)^2 | +| Gated attention | Headwise (1 sigmoid gate per head) | +| QK-Gain init | 5.0 | +| Vocab size | 8192 (SentencePiece BPE) | +| Embeddings | Tied | +| Logit softcap | 30.0 | +| RoPE base | 10000.0 | +| Sequence length | 1024 | +| Batch tokens | 524,288 | +| Optimizer | Muon (matrix params) + Adam (scalars/embeddings) | +| Parameters | 16,364,616 | +| Quantization | int8 + zlib | + +## Original Contribution + +**Headwise Gated Attention** is an original technique developed for this submission. It applies a learned sigmoid gate to each attention head's output, allowing the model to dynamically suppress uninformative heads on a per-token basis. The gate is a simple linear projection from the input features to `num_heads` scalars, passed through sigmoid, then multiplied with the SDPA output. + +This is distinct from the elementwise variant (which gates every dimension independently but adds too many parameters for the 16 MB budget) and from attention head pruning (which permanently removes heads rather than dynamically gating them). + +The technique was inspired by the NeurIPS 2025 Best Paper "Gated Attention" (arxiv.org/abs/2505.06708), which demonstrated gating mechanisms for attention in vision transformers. We adapted the concept to the language modeling setting with a lightweight per-head variant. + +## Compliance + +- [x] Training under 600s (all seeds: ~600s) +- [x] Artifact under 16 MB (all seeds: ~15.34 MB, +0.54 MB headroom) +- [x] Eval under 600s (TTT eval: ~137s per seed) +- [x] No SLOT (no supervised learning on test data) +- [x] No pre-quantization TTT (TTT runs on int8+zlib quantized model) +- [x] No ETLB (no eval-time learned biases) +- [x] No n-gram cache +- [x] Score-first TTT (score chunk before training on it, last chunk score-only) +- [x] Three seeds (1337, 42, 2025) + +## Run Command + +```bash +# Download SP8192 dataset (NOT in official repo) +rm -f data/manifest.json +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf \ + python3 data/cached_challenge_fineweb.py --variant sp8192 --train-shards 80 + +# Train (change SEED for different seeds) +GATED_ATTN=headwise \ +ACTIVATION=leaky_relu2 \ +QK_GAIN_INIT=5.0 \ +VOCAB_SIZE=8192 \ +DATA_PATH=./data/datasets/fineweb10B_sp8192/ \ +TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model \ +MODEL_DIM=448 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +NUM_LAYERS=9 \ +TTT_MODE=score_first \ +TTT_LR=0.005 \ +TTT_MOMENTUM=0.9 \ +TTT_EPOCHS=3 \ +TTT_CHUNK_TOKENS=32768 \ +TTT_FREEZE_BLOCKS=0 \ +TTT_GRAD_CLIP=1.0 \ +TTT_BATCH_SEQS=32 \ +MAX_WALLCLOCK_SECONDS=600 \ +SEED=1337 \ +RUN_ID=sp8192_combo_slim \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Ablation + +Experiments run on 2xH100 (SP1024, PyTorch 2.11) to isolate technique contributions: + +| Technique | val_bpb | Delta vs baseline | +|-----------|---------|-------------------| +| Baseline (GQA, SP1024) | 1.2649 | -- | +| + LeakyReLU2 | 1.2641 | -0.0008 | +| + Headwise gated attn | 1.2653 | -0.0 (trades speed for quality) | +| + LeakyReLU2 + headwise | 1.2642 | -0.0007 (don't stack on SP1024) | +| SP8192 combo slim + TTT (8xH100) | **1.2073** | **-0.0171** | + +The dominant factor is SP8192 + TTT. LeakyReLU2 provides a small free improvement. Headwise gated attention contributes more with larger vocabularies where attention pattern diversity matters. + +## Credits + +- SP8192 tokenizer and dataset: @kevclark (PR #1394) +- LeakyReLU(0.5)^2 activation: @abaybektursun (PR #549) +- Score-First TTT + QK-Gain 5.0: @dexhunter (PR #1413) +- Headwise Gated Attention: Original -- James Vo +- Base training infrastructure: modded-nanogpt / OpenAI Parameter Golf + +## Included Files + +- `README.md` (this file) +- `submission.json` +- `train_gpt.py` (1529 lines, 73,221 bytes) +- `train_seed1337.log` +- `train_seed42.log` +- `train_seed2025.log` diff --git a/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/submission.json b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/submission.json new file mode 100644 index 0000000000..15efdc811a --- /dev/null +++ b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/submission.json @@ -0,0 +1,40 @@ +{ + "author": "James Vo", + "github_id": "jamesEmerson112", + "name": "SP8192 + Headwise Gated Attention + LeakyReLU2 + QK-Gain 5.0 + Legal TTT", + "blurb": "SP8192 + headwise gated attention (original technique: sigmoid gate per head post-SDPA, inspired by NeurIPS 2025 Best Paper arxiv.org/abs/2505.06708) + LeakyReLU(0.5)^2 + QK-Gain 5.0 + score-first TTT. 3-seed mean 1.2073 BPB (std 0.0006). MODEL_DIM=448 to fit 16 MB budget.", + "date": "2026-04-24", + "track": "10min_16mb", + "val_bpb": 1.20731, + "val_bpb_std": 0.00060, + "seeds": [1337, 42, 2025], + "seed_results": { + "1337": {"val_bpb": 1.20665395, "val_loss": 3.11691247, "artifact_bytes": 15340947}, + "42": {"val_bpb": 1.20783103, "val_loss": 3.11995300, "artifact_bytes": 15340685}, + "2025": {"val_bpb": 1.20745783, "val_loss": 3.11898899, "artifact_bytes": 15337072} + }, + "bytes_total": 15414168, + "bytes_code": 73221, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.7.1", + "tokenizer": "SentencePiece BPE 8192", + "architecture": "9L 448d 8h 4kv GQA headwise-gate LeakyReLU2 QK5.0 tied-embed", + "compliance": { + "train_under_600s": true, + "artifact_under_16mb": true, + "eval_under_600s": true, + "no_slot": true, + "no_pre_quant_ttt": true, + "no_etlb": true, + "no_ngram_cache": true, + "score_first_ttt": true, + "three_seeds": true + }, + "attribution": { + "SP8192": "@kevclark (PR #1394)", + "LeakyReLU2": "@abaybektursun (PR #549)", + "Score-First TTT": "@dexhunter (PR #1413)", + "QK-Gain 5.0": "@dexhunter (PR #1413)", + "Headwise Gated Attention": "Original - James Vo (inspired by arxiv.org/abs/2505.06708)" + } +} diff --git a/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_gpt.py b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_gpt.py new file mode 100644 index 0000000000..3511937b75 --- /dev/null +++ b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_gpt.py @@ -0,0 +1,1529 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + # Gated Attention (https://arxiv.org/abs/2505.06708, NeurIPS 2025 Best Paper) + # "none" = disabled (baseline), "headwise" = 1 gate per head, "elementwise" = 1 gate per dim + gated_attn = os.environ.get("GATED_ATTN", "none") + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Activation: "relu2" (default), "leaky_relu2" = LeakyReLU(0.5)² (PG ranks 10-11) + activation = os.environ.get("ACTIVATION", "relu2") + + # Test-Time Training (Score-First TTT, PG ranks 1-3) + # "none" = disabled (baseline), "score_first" = legal score-before-update TTT + ttt_mode = os.environ.get("TTT_MODE", "none") + ttt_lr = float(os.environ.get("TTT_LR", "0.005")) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", "0.9")) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "3")) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", "32768")) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", "0")) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", "1.0")) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", "32")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# SCORE-FIRST TTT (Legal Test-Time Training) +# ----------------------------- +# Protocol (PR #461 recipe, PG ranks 1-3): +# 1. Chunk validation data into ttt_chunk_tokens windows +# 2. For each chunk: SCORE all tokens under inference_mode(), then TRAIN on scored data +# 3. Last chunk: score only (never train on un-scored data) +# Guarantees: every token scored BEFORE any parameter update that could use it. + +def eval_val_score_first_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + stride = seq_len # non-overlapping windows for simplicity + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + + # Assign windows to chunks + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + ci = min(ws // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"windows={len(window_starts)} ttt_lr={args.ttt_lr} " + f"ttt_epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = any(f"blocks.{bi}." in name for bi in frozen_block_ids) + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- PHASE 1: SCORE (no gradients) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), args.ttt_batch_seqs): + batch_ws = my_windows[bi:bi + args.ttt_batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i in range(bsz): + wlen = wlens[i] + scored_nll = nll[i, :wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen) + tgt = y_batch[i, :wlen] + prev = x_batch[i, :wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- PHASE 2: TRAIN on scored chunk (skip last chunk) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR decay across chunks + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + + for _ep in range(args.ttt_epochs): + for bs in range(my_seq_s, my_seq_e, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_seq_e) + start_tok = chunk_start + bs * seq_len + end_tok = chunk_start + be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + # AllReduce across ranks + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore grad for all params + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +# ╔═══════════════════════════════════════════════════════════════════════════════════╗ +# ║ CAUSAL SELF-ATTENTION (with GQA) ║ +# ║ ║ +# ║ Plain English: Each token asks "which earlier tokens should I pay attention to?" ║ +# ║ It computes a compatibility score (Q·K) with every previous token, then takes ║ +# ║ a weighted average of their values (V). "Causal" = can only look backward. ║ +# ║ ║ +# ║ ┌──────────────── Input x: [batch, seq_len, 512] ────────────────┐ ║ +# ║ │ │ │ ║ +# ║ │ ┌──────────────────┼──────────────────┐ │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ ▼ ▼ ▼ │ ║ +# ║ │ c_q (512→512) c_k (512→256) c_v (512→256) │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ ▼ ▼ ▼ │ ║ +# ║ │ Q: 8 heads × 64d K: 4 heads × 64d V: 4 heads × 64d │ ║ +# ║ │ │ │ │ ║ +# ║ │ ▼ ▼ │ ║ +# ║ │ RMSNorm(Q) RMSNorm(K) ← stabilize magnitudes │ ║ +# ║ │ │ │ │ ║ +# ║ │ ▼ ▼ │ ║ +# ║ │ RoPE(Q) RoPE(K) ← encode positions │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ Q × q_gain ← learnable sharpness │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ └──────────────────┼──────────────────┘ │ ║ +# ║ │ ▼ │ ║ +# ║ │ scaled_dot_product_attention │ ║ +# ║ │ softmax(Q·Kᵀ / √64) · V │ ║ +# ║ │ (causal mask + FlashAttention + GQA) │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ proj (512→512) │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ Output: [batch, seq_len, 512] │ ║ +# ║ └─────────────────────────────────────────────────────────────────┘ ║ +# ║ ║ +# ║ GROUPED QUERY ATTENTION (GQA) — why K,V have 4 heads but Q has 8: ║ +# ║ ║ +# ║ Standard (MHA, 8×8): This model (GQA, 8×4): ║ +# ║ Q₀ → K₀ Q₄ → K₄ Q₀ ─┐ ║ +# ║ Q₁ → K₁ Q₅ → K₅ Q₁ ─┤→ K₀,V₀ ║ +# ║ Q₂ → K₂ Q₆ → K₆ Q₂ ─┐ ║ +# ║ Q₃ → K₃ Q₇ → K₇ Q₃ ─┤→ K₁,V₁ ║ +# ║ (8 KV pairs = full cost) Q₄ ─┐ ║ +# ║ Q₅ ─┤→ K₂,V₂ ║ +# ║ Q₆ ─┐ ║ +# ║ Q₇ ─┤→ K₃,V₃ ║ +# ║ (4 KV pairs = 50% memory) ║ +# ║ ║ +# ║ CAUSAL MASK (is_causal=True) — what makes this a language model: ║ +# ║ ║ +# ║ attends to→ t₀ t₁ t₂ t₃ t₄ ║ +# ║ t₀ ✓ ✗ ✗ ✗ ✗ "The" can only see itself ║ +# ║ t₁ ✓ ✓ ✗ ✗ ✗ "cat" sees "The cat" ║ +# ║ t₂ ✓ ✓ ✓ ✗ ✗ "sat" sees "The cat sat" ║ +# ║ t₃ ✓ ✓ ✓ ✓ ✗ "on" sees "The cat sat on" ║ +# ║ t₄ ✓ ✓ ✓ ✓ ✓ "the" sees everything before it ║ +# ║ ║ +# ║ Each token predicts the NEXT token without peeking at the answer. ║ +# ║ ║ +# ║ WHY EACH STEP EXISTS: ║ +# ║ • RMSNorm on Q,K → prevents attention scores from exploding to ±inf ║ +# ║ • RoPE on Q,K → encodes WHERE each token is (position 0, 1, 2...) ║ +# ║ • q_gain (4.0-5.25) → learnable per-head sharpness (higher = more focused) ║ +# ║ • √64 denominator → scales dot products so softmax doesn't saturate ║ +# ║ • proj → mixes info from all 8 heads back to 512 dimensions ║ +# ╚═══════════════════════════════════════════════════════════════════════════════════╝ +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads # 512 / 8 = 64 dimensions per head + self.gated_attn = gated_attn + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim # 4 × 64 = 256 (half of Q's 512, GQA savings) + # Gated Attention (arxiv.org/abs/2505.06708, NeurIPS 2025 Best Paper): + # Widen Q projection to also produce gate logits from the same input. + # "none" → c_q: 512→512, no gate + # "headwise" → c_q: 512→520, +8 dims (1 gate scalar per head, ~9K extra params/layer) + # "elementwise" → c_q: 512→1024, +512 dims (1 gate per dimension, doubles Q projection) + if gated_attn == "headwise": + self.gate_dim = num_heads # 8 extra outputs + elif gated_attn == "elementwise": + self.gate_dim = dim # 512 extra outputs + else: + self.gate_dim = 0 + self.c_q = CastedLinear(dim, dim + self.gate_dim, bias=False) # Query + gate logits + self.c_k = CastedLinear(dim, kv_dim, bias=False) # Key: 512→256 (only 4 KV heads) + self.c_v = CastedLinear(dim, kv_dim, bias=False) # Value: 512→256 (only 4 KV heads) + self.proj = CastedLinear(dim, dim, bias=False) # Output projection: 512→512 + self.proj._zero_init = True # Start at zero so skip connections dominate early + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + # Step 1: Project input → Q (+ gate logits if gated attention is enabled) + q_out = self.c_q(x) + if self.gate_dim > 0: + q_raw, gate_logits = q_out.split([dim, self.gate_dim], dim=-1) + else: + q_raw = q_out + # Reshape into multi-head format: [batch, heads, seq, head_dim] + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Step 2: Normalize Q and K so dot products don't explode + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + # Step 3: Apply rotary position embeddings — encode WHERE each token sits + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + # Step 4: Scale Q by learnable per-head gain — controls attention sharpness + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Step 5: The actual attention — softmax(Q·Kᵀ/√d) · V with causal mask + # Uses FlashAttention kernel under the hood (memory-efficient, no full NxN matrix) + # GQA: repeat K,V heads to match Q heads (e.g. 4 KV → 8 Q, repeat 2×) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + ) + # Step 6 (Gated Attention): sigmoid gate AFTER attention, BEFORE output projection + # Gate logits come from the Q projection (query-dependent, input-dependent). + # sigmoid(gate) ∈ [0,1] lets the model suppress uninformative heads/dims per token. + # Headwise: [bsz, seq, 8] → [bsz, 8, seq, 1] (one scalar per head) + # Elementwise: [bsz, seq, 512] → [bsz, 8, seq, 64] (one per dimension) + if self.gated_attn == "headwise": + gate = torch.sigmoid(gate_logits) # [bsz, seqlen, num_heads] + gate = gate.transpose(1, 2).unsqueeze(-1) # [bsz, num_heads, seqlen, 1] + y = y * gate + elif self.gated_attn == "elementwise": + gate = torch.sigmoid(gate_logits) # [bsz, seqlen, dim] + gate = gate.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + y = y * gate + # Step 7: Reshape from multi-head back to [batch, seq, 512] and project + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, activation: str = "relu2"): + super().__init__() + self.activation = activation + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = self.fc(x) + if self.activation == "leaky_relu2": + x = F.leaky_relu(x, negative_slope=0.5) + else: + x = torch.relu(x) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + activation: str = "relu2", + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, gated_attn) + self.mlp = MLP(dim, mlp_mult, activation) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + activation: str = "relu2", + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + gated_attn, + activation, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _run_backbone(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits [batch, seq, vocab] without computing loss.""" + return self._run_backbone(input_ids) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self._run_backbone(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + gated_attn=args.gated_attn, + activation=args.activation, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # ===================================================================== + # OPTIMIZER SPLIT — Different param types get different optimizers & LRs + # ===================================================================== + # + # The model's parameters are sorted into 3-4 groups, each with its own + # optimizer and learning rate. Think of it like a team where each role + # needs a different management style: + # + # ┌─────────────────────────────────────────────────────────────────┐ + # │ Parameter Group Optimizer LR % of Params │ + # │ ───────────────────── ───────── ────── ─────────── │ + # │ 1. Embedding table Adam 0.05 ~3% (SP1024) │ + # │ (tok_emb.weight) ~12% (SP8192) │ + # │ "the dictionary" │ + # │ │ + # │ 2. Matrix weights MUON 0.04 ~95% of blocks │ + # │ (Q, K, V, proj, The heavy lifters │ + # │ fc, mlp.proj) — 2D matrices │ + # │ "attention + MLP that do the real │ + # │ computation" learning │ + # │ │ + # │ 3. Scalar params Adam 0.04 <1% │ + # │ (attn_scale, Tiny knobs that │ + # │ mlp_scale, fine-tune how │ + # │ resid_mix, blocks blend │ + # │ q_gain, their outputs │ + # │ skip_weights) │ + # │ │ + # │ 4. Output head Adam 0.008 Only if untied │ + # │ (lm_head.weight) embeddings │ + # │ "vector → token (not used in │ + # │ probability" baseline) │ + # └─────────────────────────────────────────────────────────────────┘ + # + # WHY MUON FOR MATRICES? + # Muon orthogonalizes gradient updates via Newton-Schulz iterations. + # This only works on 2D matrices (needs rows & columns to orthogonalize). + # 1D vectors and scalars are too small — they use standard Adam instead. + # + # WHY DIFFERENT LEARNING RATES? + # - Embeddings are a lookup table — aggressive updates cause instability + # - Matrix weights are the core compute — Muon handles the LR scaling + # - Scalars are sensitive knobs — moderate LR with Adam's adaptive step + # - Output head (if untied) maps back to vocab — needs gentler updates + # + block_named_params = list(base_model.blocks.named_parameters()) + + # --- Group 2: Matrix params (2D tensors in transformer blocks) → Muon --- + # Collects: c_q.weight, c_k.weight, c_v.weight, proj.weight, fc.weight, mlp.proj.weight + # Excludes: control tensors like attn_scale, resid_mix (those are 2D but act as scalars) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + + # --- Group 3: Scalar params (1D vectors + control tensors) → Adam --- + # Collects: attn_scale, mlp_scale, resid_mix, q_gain, skip_weights + # These are the small "knobs" that tune how blocks combine their outputs + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) # U-Net skip connection weights + + # --- Group 1: Embedding table → Adam (gentle LR) --- + # The "dictionary" that maps token IDs to 512-dim vectors + # If tied: same weight used for input embedding AND output prediction (LR=0.05) + # If untied: separate input embedding (LR=0.6) and output head (LR=0.008) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, # fused=True: single GPU kernel for speed + ) + + # --- Group 2 optimizer: Muon for the heavy 2D matrices --- + # Uses Newton-Schulz orthogonalization (5 steps) to normalize gradients + # before applying them. This is the key innovation from modded-nanogpt. + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr # Stash base LR for warmdown scheduling + + # --- Group 3 optimizer: Adam for small scalar/vector params --- + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + + # Collect all optimizers — training loop will step() all of them each iteration + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + + # --- Group 4 (optional): Output head → Adam (only when embeddings are untied) --- + # In baseline, tie_embeddings=True so lm_head is None and this is skipped. + # When untied, the output head gets its own gentle LR (0.008) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Score-First TTT on quantized model (if enabled) + if args.ttt_mode == "score_first": + torch.cuda.synchronize() + t_ttt = time.perf_counter() + # Save quantized weights, TTT will modify them during adaptation + saved_state = copy.deepcopy(base_model.state_dict()) + ttt_loss, ttt_bpb = eval_val_score_first_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + log0=log0, + ) + # Restore quantized weights (TTT modifies them in-place) + base_model.load_state_dict(saved_state) + torch.cuda.synchronize() + log0(f"final_int8_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"final_int8_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed1337.log b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed1337.log new file mode 100644 index 0000000000..99a711ae79 --- /dev/null +++ b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed1337.log @@ -0,0 +1,1929 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + # Gated Attention (https://arxiv.org/abs/2505.06708, NeurIPS 2025 Best Paper) + # "none" = disabled (baseline), "headwise" = 1 gate per head, "elementwise" = 1 gate per dim + gated_attn = os.environ.get("GATED_ATTN", "none") + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Activation: "relu2" (default), "leaky_relu2" = LeakyReLU(0.5)² (PG ranks 10-11) + activation = os.environ.get("ACTIVATION", "relu2") + + # Test-Time Training (Score-First TTT, PG ranks 1-3) + # "none" = disabled (baseline), "score_first" = legal score-before-update TTT + ttt_mode = os.environ.get("TTT_MODE", "none") + ttt_lr = float(os.environ.get("TTT_LR", "0.005")) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", "0.9")) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "3")) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", "32768")) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", "0")) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", "1.0")) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", "32")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# SCORE-FIRST TTT (Legal Test-Time Training) +# ----------------------------- +# Protocol (PR #461 recipe, PG ranks 1-3): +# 1. Chunk validation data into ttt_chunk_tokens windows +# 2. For each chunk: SCORE all tokens under inference_mode(), then TRAIN on scored data +# 3. Last chunk: score only (never train on un-scored data) +# Guarantees: every token scored BEFORE any parameter update that could use it. + +def eval_val_score_first_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + stride = seq_len # non-overlapping windows for simplicity + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + + # Assign windows to chunks + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + ci = min(ws // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"windows={len(window_starts)} ttt_lr={args.ttt_lr} " + f"ttt_epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = any(f"blocks.{bi}." in name for bi in frozen_block_ids) + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- PHASE 1: SCORE (no gradients) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), args.ttt_batch_seqs): + batch_ws = my_windows[bi:bi + args.ttt_batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i in range(bsz): + wlen = wlens[i] + scored_nll = nll[i, :wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen) + tgt = y_batch[i, :wlen] + prev = x_batch[i, :wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- PHASE 2: TRAIN on scored chunk (skip last chunk) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR decay across chunks + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + + for _ep in range(args.ttt_epochs): + for bs in range(my_seq_s, my_seq_e, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_seq_e) + start_tok = chunk_start + bs * seq_len + end_tok = chunk_start + be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + # AllReduce across ranks + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore grad for all params + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +# ╔═══════════════════════════════════════════════════════════════════════════════════╗ +# ║ CAUSAL SELF-ATTENTION (with GQA) ║ +# ║ ║ +# ║ Plain English: Each token asks "which earlier tokens should I pay attention to?" ║ +# ║ It computes a compatibility score (Q·K) with every previous token, then takes ║ +# ║ a weighted average of their values (V). "Causal" = can only look backward. ║ +# ║ ║ +# ║ ┌──────────────── Input x: [batch, seq_len, 512] ────────────────┐ ║ +# ║ │ │ │ ║ +# ║ │ ┌──────────────────┼──────────────────┐ │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ ▼ ▼ ▼ │ ║ +# ║ │ c_q (512→512) c_k (512→256) c_v (512→256) │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ ▼ ▼ ▼ │ ║ +# ║ │ Q: 8 heads × 64d K: 4 heads × 64d V: 4 heads × 64d │ ║ +# ║ │ │ │ │ ║ +# ║ │ ▼ ▼ │ ║ +# ║ │ RMSNorm(Q) RMSNorm(K) ← stabilize magnitudes │ ║ +# ║ │ │ │ │ ║ +# ║ │ ▼ ▼ │ ║ +# ║ │ RoPE(Q) RoPE(K) ← encode positions │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ Q × q_gain ← learnable sharpness │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ └──────────────────┼──────────────────┘ │ ║ +# ║ │ ▼ │ ║ +# ║ │ scaled_dot_product_attention │ ║ +# ║ │ softmax(Q·Kᵀ / √64) · V │ ║ +# ║ │ (causal mask + FlashAttention + GQA) │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ proj (512→512) │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ Output: [batch, seq_len, 512] │ ║ +# ║ └─────────────────────────────────────────────────────────────────┘ ║ +# ║ ║ +# ║ GROUPED QUERY ATTENTION (GQA) — why K,V have 4 heads but Q has 8: ║ +# ║ ║ +# ║ Standard (MHA, 8×8): This model (GQA, 8×4): ║ +# ║ Q₀ → K₀ Q₄ → K₄ Q₀ ─┐ ║ +# ║ Q₁ → K₁ Q₅ → K₅ Q₁ ─┤→ K₀,V₀ ║ +# ║ Q₂ → K₂ Q₆ → K₆ Q₂ ─┐ ║ +# ║ Q₃ → K₃ Q₇ → K₇ Q₃ ─┤→ K₁,V₁ ║ +# ║ (8 KV pairs = full cost) Q₄ ─┐ ║ +# ║ Q₅ ─┤→ K₂,V₂ ║ +# ║ Q₆ ─┐ ║ +# ║ Q₇ ─┤→ K₃,V₃ ║ +# ║ (4 KV pairs = 50% memory) ║ +# ║ ║ +# ║ CAUSAL MASK (is_causal=True) — what makes this a language model: ║ +# ║ ║ +# ║ attends to→ t₀ t₁ t₂ t₃ t₄ ║ +# ║ t₀ ✓ ✗ ✗ ✗ ✗ "The" can only see itself ║ +# ║ t₁ ✓ ✓ ✗ ✗ ✗ "cat" sees "The cat" ║ +# ║ t₂ ✓ ✓ ✓ ✗ ✗ "sat" sees "The cat sat" ║ +# ║ t₃ ✓ ✓ ✓ ✓ ✗ "on" sees "The cat sat on" ║ +# ║ t₄ ✓ ✓ ✓ ✓ ✓ "the" sees everything before it ║ +# ║ ║ +# ║ Each token predicts the NEXT token without peeking at the answer. ║ +# ║ ║ +# ║ WHY EACH STEP EXISTS: ║ +# ║ • RMSNorm on Q,K → prevents attention scores from exploding to ±inf ║ +# ║ • RoPE on Q,K → encodes WHERE each token is (position 0, 1, 2...) ║ +# ║ • q_gain (4.0-5.25) → learnable per-head sharpness (higher = more focused) ║ +# ║ • √64 denominator → scales dot products so softmax doesn't saturate ║ +# ║ • proj → mixes info from all 8 heads back to 512 dimensions ║ +# ╚═══════════════════════════════════════════════════════════════════════════════════╝ +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads # 512 / 8 = 64 dimensions per head + self.gated_attn = gated_attn + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim # 4 × 64 = 256 (half of Q's 512, GQA savings) + # Gated Attention (arxiv.org/abs/2505.06708, NeurIPS 2025 Best Paper): + # Widen Q projection to also produce gate logits from the same input. + # "none" → c_q: 512→512, no gate + # "headwise" → c_q: 512→520, +8 dims (1 gate scalar per head, ~9K extra params/layer) + # "elementwise" → c_q: 512→1024, +512 dims (1 gate per dimension, doubles Q projection) + if gated_attn == "headwise": + self.gate_dim = num_heads # 8 extra outputs + elif gated_attn == "elementwise": + self.gate_dim = dim # 512 extra outputs + else: + self.gate_dim = 0 + self.c_q = CastedLinear(dim, dim + self.gate_dim, bias=False) # Query + gate logits + self.c_k = CastedLinear(dim, kv_dim, bias=False) # Key: 512→256 (only 4 KV heads) + self.c_v = CastedLinear(dim, kv_dim, bias=False) # Value: 512→256 (only 4 KV heads) + self.proj = CastedLinear(dim, dim, bias=False) # Output projection: 512→512 + self.proj._zero_init = True # Start at zero so skip connections dominate early + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + # Step 1: Project input → Q (+ gate logits if gated attention is enabled) + q_out = self.c_q(x) + if self.gate_dim > 0: + q_raw, gate_logits = q_out.split([dim, self.gate_dim], dim=-1) + else: + q_raw = q_out + # Reshape into multi-head format: [batch, heads, seq, head_dim] + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Step 2: Normalize Q and K so dot products don't explode + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + # Step 3: Apply rotary position embeddings — encode WHERE each token sits + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + # Step 4: Scale Q by learnable per-head gain — controls attention sharpness + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Step 5: The actual attention — softmax(Q·Kᵀ/√d) · V with causal mask + # Uses FlashAttention kernel under the hood (memory-efficient, no full NxN matrix) + # GQA: repeat K,V heads to match Q heads (e.g. 4 KV → 8 Q, repeat 2×) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + ) + # Step 6 (Gated Attention): sigmoid gate AFTER attention, BEFORE output projection + # Gate logits come from the Q projection (query-dependent, input-dependent). + # sigmoid(gate) ∈ [0,1] lets the model suppress uninformative heads/dims per token. + # Headwise: [bsz, seq, 8] → [bsz, 8, seq, 1] (one scalar per head) + # Elementwise: [bsz, seq, 512] → [bsz, 8, seq, 64] (one per dimension) + if self.gated_attn == "headwise": + gate = torch.sigmoid(gate_logits) # [bsz, seqlen, num_heads] + gate = gate.transpose(1, 2).unsqueeze(-1) # [bsz, num_heads, seqlen, 1] + y = y * gate + elif self.gated_attn == "elementwise": + gate = torch.sigmoid(gate_logits) # [bsz, seqlen, dim] + gate = gate.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + y = y * gate + # Step 7: Reshape from multi-head back to [batch, seq, 512] and project + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, activation: str = "relu2"): + super().__init__() + self.activation = activation + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = self.fc(x) + if self.activation == "leaky_relu2": + x = F.leaky_relu(x, negative_slope=0.5) + else: + x = torch.relu(x) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + activation: str = "relu2", + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, gated_attn) + self.mlp = MLP(dim, mlp_mult, activation) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + activation: str = "relu2", + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + gated_attn, + activation, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _run_backbone(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits [batch, seq, vocab] without computing loss.""" + return self._run_backbone(input_ids) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self._run_backbone(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + gated_attn=args.gated_attn, + activation=args.activation, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # ===================================================================== + # OPTIMIZER SPLIT — Different param types get different optimizers & LRs + # ===================================================================== + # + # The model's parameters are sorted into 3-4 groups, each with its own + # optimizer and learning rate. Think of it like a team where each role + # needs a different management style: + # + # ┌─────────────────────────────────────────────────────────────────┐ + # │ Parameter Group Optimizer LR % of Params │ + # │ ───────────────────── ───────── ────── ─────────── │ + # │ 1. Embedding table Adam 0.05 ~3% (SP1024) │ + # │ (tok_emb.weight) ~12% (SP8192) │ + # │ "the dictionary" │ + # │ │ + # │ 2. Matrix weights MUON 0.04 ~95% of blocks │ + # │ (Q, K, V, proj, The heavy lifters │ + # │ fc, mlp.proj) — 2D matrices │ + # │ "attention + MLP that do the real │ + # │ computation" learning │ + # │ │ + # │ 3. Scalar params Adam 0.04 <1% │ + # │ (attn_scale, Tiny knobs that │ + # │ mlp_scale, fine-tune how │ + # │ resid_mix, blocks blend │ + # │ q_gain, their outputs │ + # │ skip_weights) │ + # │ │ + # │ 4. Output head Adam 0.008 Only if untied │ + # │ (lm_head.weight) embeddings │ + # │ "vector → token (not used in │ + # │ probability" baseline) │ + # └─────────────────────────────────────────────────────────────────┘ + # + # WHY MUON FOR MATRICES? + # Muon orthogonalizes gradient updates via Newton-Schulz iterations. + # This only works on 2D matrices (needs rows & columns to orthogonalize). + # 1D vectors and scalars are too small — they use standard Adam instead. + # + # WHY DIFFERENT LEARNING RATES? + # - Embeddings are a lookup table — aggressive updates cause instability + # - Matrix weights are the core compute — Muon handles the LR scaling + # - Scalars are sensitive knobs — moderate LR with Adam's adaptive step + # - Output head (if untied) maps back to vocab — needs gentler updates + # + block_named_params = list(base_model.blocks.named_parameters()) + + # --- Group 2: Matrix params (2D tensors in transformer blocks) → Muon --- + # Collects: c_q.weight, c_k.weight, c_v.weight, proj.weight, fc.weight, mlp.proj.weight + # Excludes: control tensors like attn_scale, resid_mix (those are 2D but act as scalars) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + + # --- Group 3: Scalar params (1D vectors + control tensors) → Adam --- + # Collects: attn_scale, mlp_scale, resid_mix, q_gain, skip_weights + # These are the small "knobs" that tune how blocks combine their outputs + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) # U-Net skip connection weights + + # --- Group 1: Embedding table → Adam (gentle LR) --- + # The "dictionary" that maps token IDs to 512-dim vectors + # If tied: same weight used for input embedding AND output prediction (LR=0.05) + # If untied: separate input embedding (LR=0.6) and output head (LR=0.008) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, # fused=True: single GPU kernel for speed + ) + + # --- Group 2 optimizer: Muon for the heavy 2D matrices --- + # Uses Newton-Schulz orthogonalization (5 steps) to normalize gradients + # before applying them. This is the key innovation from modded-nanogpt. + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr # Stash base LR for warmdown scheduling + + # --- Group 3 optimizer: Adam for small scalar/vector params --- + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + + # Collect all optimizers — training loop will step() all of them each iteration + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + + # --- Group 4 (optional): Output head → Adam (only when embeddings are untied) --- + # In baseline, tie_embeddings=True so lm_head is None and this is skipped. + # When untied, the output head gets its own gentle LR (0.008) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Score-First TTT on quantized model (if enabled) + if args.ttt_mode == "score_first": + torch.cuda.synchronize() + t_ttt = time.perf_counter() + # Save quantized weights, TTT will modify them during adaptation + saved_state = copy.deepcopy(base_model.state_dict()) + ttt_loss, ttt_bpb = eval_val_score_first_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + log0=log0, + ) + # Restore quantized weights (TTT modifies them in-place) + base_model.load_state_dict(saved_state) + torch.cuda.synchronize() + log0(f"final_int8_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"final_int8_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Thu Apr 23 23:58:45 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 33C P0 114W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 33C P0 120W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 35C P0 114W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 116W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 114W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_8192_bpe.model +train_loader:dataset:fineweb10B_sp8192 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin tokens:40540160 +model_params:16364616 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:9.0078 val_bpb:3.4872 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:9.0088 train_time:35ms step_avg:34.63ms +step:2/20000 train_loss:15.5891 train_time:84ms step_avg:42.10ms +step:3/20000 train_loss:11.2669 train_time:138ms step_avg:45.97ms +step:4/20000 train_loss:9.0208 train_time:192ms step_avg:48.07ms +step:5/20000 train_loss:8.5225 train_time:247ms step_avg:49.30ms +step:6/20000 train_loss:8.4845 train_time:301ms step_avg:50.15ms +step:7/20000 train_loss:8.2421 train_time:355ms step_avg:50.74ms +step:8/20000 train_loss:8.0542 train_time:410ms step_avg:51.19ms +step:9/20000 train_loss:7.7000 train_time:464ms step_avg:51.53ms +step:10/20000 train_loss:7.5769 train_time:518ms step_avg:51.81ms +step:100/20000 train_loss:4.6086 train_time:5518ms step_avg:55.18ms +step:200/20000 train_loss:3.9402 train_time:11005ms step_avg:55.02ms +step:200/20000 val_loss:3.9500 val_bpb:1.5292 train_time:11026ms step_avg:55.13ms +step:300/20000 train_loss:3.8180 train_time:16412ms step_avg:54.71ms +step:400/20000 train_loss:3.6614 train_time:21882ms step_avg:54.71ms +step:400/20000 val_loss:3.6603 val_bpb:1.4170 train_time:21904ms step_avg:54.76ms +step:500/20000 train_loss:3.6033 train_time:27282ms step_avg:54.56ms +step:600/20000 train_loss:3.5454 train_time:32739ms step_avg:54.56ms +step:600/20000 val_loss:3.5425 val_bpb:1.3714 train_time:32759ms step_avg:54.60ms +step:700/20000 train_loss:3.3691 train_time:38131ms step_avg:54.47ms +step:800/20000 train_loss:3.4430 train_time:43586ms step_avg:54.48ms +step:800/20000 val_loss:3.4659 val_bpb:1.3417 train_time:43607ms step_avg:54.51ms +step:900/20000 train_loss:3.5918 train_time:48981ms step_avg:54.42ms +step:1000/20000 train_loss:3.1732 train_time:54437ms step_avg:54.44ms +step:1000/20000 val_loss:3.4180 val_bpb:1.3232 train_time:54459ms step_avg:54.46ms +step:1100/20000 train_loss:3.4172 train_time:59845ms step_avg:54.40ms +step:1200/20000 train_loss:3.4110 train_time:65317ms step_avg:54.43ms +step:1200/20000 val_loss:3.3832 val_bpb:1.3098 train_time:65339ms step_avg:54.45ms +step:1300/20000 train_loss:3.3989 train_time:70724ms step_avg:54.40ms +step:1400/20000 train_loss:3.4809 train_time:76189ms step_avg:54.42ms +step:1400/20000 val_loss:3.3613 val_bpb:1.3013 train_time:76211ms step_avg:54.44ms +step:1500/20000 train_loss:3.5421 train_time:81606ms step_avg:54.40ms +step:1600/20000 train_loss:3.5551 train_time:87082ms step_avg:54.43ms +step:1600/20000 val_loss:3.3432 val_bpb:1.2943 train_time:87103ms step_avg:54.44ms +step:1700/20000 train_loss:3.4579 train_time:92490ms step_avg:54.41ms +step:1800/20000 train_loss:3.2671 train_time:97954ms step_avg:54.42ms +step:1800/20000 val_loss:3.3261 val_bpb:1.2876 train_time:97976ms step_avg:54.43ms +step:1900/20000 train_loss:3.4450 train_time:103358ms step_avg:54.40ms +step:2000/20000 train_loss:3.2871 train_time:108826ms step_avg:54.41ms +step:2000/20000 val_loss:3.3109 val_bpb:1.2817 train_time:108847ms step_avg:54.42ms +step:2100/20000 train_loss:3.1805 train_time:114289ms step_avg:54.42ms +step:2200/20000 train_loss:3.3250 train_time:119693ms step_avg:54.41ms +step:2200/20000 val_loss:3.3012 val_bpb:1.2780 train_time:119715ms step_avg:54.42ms +step:2300/20000 train_loss:3.4275 train_time:125170ms step_avg:54.42ms +step:2400/20000 train_loss:3.3016 train_time:130573ms step_avg:54.41ms +step:2400/20000 val_loss:3.2880 val_bpb:1.2729 train_time:130594ms step_avg:54.41ms +step:2500/20000 train_loss:3.1703 train_time:136034ms step_avg:54.41ms +step:2600/20000 train_loss:3.2871 train_time:141436ms step_avg:54.40ms +step:2600/20000 val_loss:3.2780 val_bpb:1.2690 train_time:141457ms step_avg:54.41ms +step:2700/20000 train_loss:3.4050 train_time:146908ms step_avg:54.41ms +step:2800/20000 train_loss:3.4570 train_time:152308ms step_avg:54.40ms +step:2800/20000 val_loss:3.2724 val_bpb:1.2669 train_time:152329ms step_avg:54.40ms +step:2900/20000 train_loss:3.2258 train_time:157779ms step_avg:54.41ms +step:3000/20000 train_loss:3.3814 train_time:163182ms step_avg:54.39ms +step:3000/20000 val_loss:3.2649 val_bpb:1.2640 train_time:163204ms step_avg:54.40ms +step:3100/20000 train_loss:3.1656 train_time:168656ms step_avg:54.41ms +step:3200/20000 train_loss:3.4771 train_time:174064ms step_avg:54.39ms +step:3200/20000 val_loss:3.2587 val_bpb:1.2615 train_time:174085ms step_avg:54.40ms +step:3300/20000 train_loss:3.2926 train_time:179530ms step_avg:54.40ms +step:3400/20000 train_loss:3.3849 train_time:184936ms step_avg:54.39ms +step:3400/20000 val_loss:3.2535 val_bpb:1.2595 train_time:184957ms step_avg:54.40ms +step:3500/20000 train_loss:3.5824 train_time:190413ms step_avg:54.40ms +step:3600/20000 train_loss:3.1954 train_time:195818ms step_avg:54.39ms +step:3600/20000 val_loss:3.2462 val_bpb:1.2567 train_time:195839ms step_avg:54.40ms +step:3700/20000 train_loss:3.2165 train_time:201283ms step_avg:54.40ms +step:3800/20000 train_loss:3.1924 train_time:206688ms step_avg:54.39ms +step:3800/20000 val_loss:3.2427 val_bpb:1.2553 train_time:206710ms step_avg:54.40ms +step:3900/20000 train_loss:3.1956 train_time:212163ms step_avg:54.40ms +step:4000/20000 train_loss:3.3084 train_time:217569ms step_avg:54.39ms +step:4000/20000 val_loss:3.2373 val_bpb:1.2532 train_time:217590ms step_avg:54.40ms +step:4100/20000 train_loss:2.6838 train_time:223033ms step_avg:54.40ms +step:4200/20000 train_loss:3.1755 train_time:228507ms step_avg:54.41ms +step:4200/20000 val_loss:3.2306 val_bpb:1.2507 train_time:228528ms step_avg:54.41ms +step:4300/20000 train_loss:3.2618 train_time:233911ms step_avg:54.40ms +step:4400/20000 train_loss:3.1398 train_time:239379ms step_avg:54.40ms +step:4400/20000 val_loss:3.2294 val_bpb:1.2502 train_time:239401ms step_avg:54.41ms +step:4500/20000 train_loss:3.1948 train_time:244786ms step_avg:54.40ms +step:4600/20000 train_loss:3.1681 train_time:250260ms step_avg:54.40ms +step:4600/20000 val_loss:3.2239 val_bpb:1.2481 train_time:250281ms step_avg:54.41ms +step:4700/20000 train_loss:3.1448 train_time:255665ms step_avg:54.40ms +step:4800/20000 train_loss:3.2815 train_time:261130ms step_avg:54.40ms +step:4800/20000 val_loss:3.2206 val_bpb:1.2468 train_time:261152ms step_avg:54.41ms +step:4900/20000 train_loss:3.1225 train_time:266539ms step_avg:54.40ms +step:5000/20000 train_loss:3.2987 train_time:272015ms step_avg:54.40ms +step:5000/20000 val_loss:3.2177 val_bpb:1.2457 train_time:272037ms step_avg:54.41ms +step:5100/20000 train_loss:3.3241 train_time:277423ms step_avg:54.40ms +step:5200/20000 train_loss:3.2819 train_time:282888ms step_avg:54.40ms +step:5200/20000 val_loss:3.2155 val_bpb:1.2448 train_time:282910ms step_avg:54.41ms +step:5300/20000 train_loss:3.2308 train_time:288301ms step_avg:54.40ms +step:5400/20000 train_loss:3.1886 train_time:293774ms step_avg:54.40ms +step:5400/20000 val_loss:3.2108 val_bpb:1.2430 train_time:293795ms step_avg:54.41ms +step:5500/20000 train_loss:3.1998 train_time:299180ms step_avg:54.40ms +step:5600/20000 train_loss:3.2984 train_time:304645ms step_avg:54.40ms +step:5600/20000 val_loss:3.2073 val_bpb:1.2416 train_time:304666ms step_avg:54.40ms +step:5700/20000 train_loss:3.0328 train_time:310053ms step_avg:54.40ms +step:5800/20000 train_loss:3.2188 train_time:315528ms step_avg:54.40ms +step:5800/20000 val_loss:3.2067 val_bpb:1.2414 train_time:315549ms step_avg:54.41ms +step:5900/20000 train_loss:3.2133 train_time:320937ms step_avg:54.40ms +step:6000/20000 train_loss:3.3359 train_time:326403ms step_avg:54.40ms +step:6000/20000 val_loss:3.2038 val_bpb:1.2403 train_time:326424ms step_avg:54.40ms +step:6100/20000 train_loss:3.2269 train_time:331811ms step_avg:54.40ms +step:6200/20000 train_loss:3.3712 train_time:337287ms step_avg:54.40ms +step:6200/20000 val_loss:3.2032 val_bpb:1.2400 train_time:337308ms step_avg:54.40ms +step:6300/20000 train_loss:3.2757 train_time:342751ms step_avg:54.40ms +step:6400/20000 train_loss:3.2800 train_time:348159ms step_avg:54.40ms +step:6400/20000 val_loss:3.2042 val_bpb:1.2404 train_time:348180ms step_avg:54.40ms +step:6500/20000 train_loss:3.2153 train_time:353635ms step_avg:54.41ms +step:6600/20000 train_loss:3.2502 train_time:359040ms step_avg:54.40ms +step:6600/20000 val_loss:3.1980 val_bpb:1.2380 train_time:359061ms step_avg:54.40ms +step:6700/20000 train_loss:2.9270 train_time:364505ms step_avg:54.40ms +step:6800/20000 train_loss:3.1250 train_time:369911ms step_avg:54.40ms +step:6800/20000 val_loss:3.1956 val_bpb:1.2371 train_time:369933ms step_avg:54.40ms +step:6900/20000 train_loss:3.2871 train_time:375387ms step_avg:54.40ms +step:7000/20000 train_loss:3.1712 train_time:380792ms step_avg:54.40ms +step:7000/20000 val_loss:3.1935 val_bpb:1.2363 train_time:380814ms step_avg:54.40ms +step:7100/20000 train_loss:3.1975 train_time:386257ms step_avg:54.40ms +step:7200/20000 train_loss:3.0835 train_time:391663ms step_avg:54.40ms +step:7200/20000 val_loss:3.1893 val_bpb:1.2347 train_time:391684ms step_avg:54.40ms +step:7300/20000 train_loss:2.9728 train_time:397139ms step_avg:54.40ms +step:7400/20000 train_loss:3.1677 train_time:402545ms step_avg:54.40ms +step:7400/20000 val_loss:3.1907 val_bpb:1.2352 train_time:402566ms step_avg:54.40ms +step:7500/20000 train_loss:3.2317 train_time:408010ms step_avg:54.40ms +step:7600/20000 train_loss:3.3071 train_time:413416ms step_avg:54.40ms +step:7600/20000 val_loss:3.1860 val_bpb:1.2334 train_time:413438ms step_avg:54.40ms +step:7700/20000 train_loss:3.1699 train_time:418890ms step_avg:54.40ms +step:7800/20000 train_loss:3.3811 train_time:424296ms step_avg:54.40ms +step:7800/20000 val_loss:3.1862 val_bpb:1.2335 train_time:424317ms step_avg:54.40ms +step:7900/20000 train_loss:3.2268 train_time:429759ms step_avg:54.40ms +step:8000/20000 train_loss:3.1732 train_time:435165ms step_avg:54.40ms +step:8000/20000 val_loss:3.1835 val_bpb:1.2324 train_time:435187ms step_avg:54.40ms +step:8100/20000 train_loss:3.2296 train_time:440638ms step_avg:54.40ms +step:8200/20000 train_loss:3.1422 train_time:446044ms step_avg:54.40ms +step:8200/20000 val_loss:3.1826 val_bpb:1.2321 train_time:446065ms step_avg:54.40ms +step:8300/20000 train_loss:3.2119 train_time:451509ms step_avg:54.40ms +step:8400/20000 train_loss:3.1516 train_time:456983ms step_avg:54.40ms +step:8400/20000 val_loss:3.1912 val_bpb:1.2354 train_time:457004ms step_avg:54.41ms +step:8500/20000 train_loss:3.2010 train_time:462388ms step_avg:54.40ms +step:8600/20000 train_loss:3.1915 train_time:467854ms step_avg:54.40ms +step:8600/20000 val_loss:3.1770 val_bpb:1.2299 train_time:467875ms step_avg:54.40ms +step:8700/20000 train_loss:3.2179 train_time:473261ms step_avg:54.40ms +step:8800/20000 train_loss:3.3301 train_time:478733ms step_avg:54.40ms +step:8800/20000 val_loss:3.1779 val_bpb:1.2303 train_time:478754ms step_avg:54.40ms +step:8900/20000 train_loss:3.2720 train_time:484138ms step_avg:54.40ms +step:9000/20000 train_loss:3.1970 train_time:489601ms step_avg:54.40ms +step:9000/20000 val_loss:3.1757 val_bpb:1.2294 train_time:489622ms step_avg:54.40ms +step:9100/20000 train_loss:3.1133 train_time:495005ms step_avg:54.40ms +step:9200/20000 train_loss:3.1765 train_time:500479ms step_avg:54.40ms +step:9200/20000 val_loss:3.1793 val_bpb:1.2308 train_time:500500ms step_avg:54.40ms +step:9300/20000 train_loss:3.2259 train_time:505886ms step_avg:54.40ms +step:9400/20000 train_loss:3.1610 train_time:511349ms step_avg:54.40ms +step:9400/20000 val_loss:3.1710 val_bpb:1.2276 train_time:511370ms step_avg:54.40ms +step:9500/20000 train_loss:3.0916 train_time:516755ms step_avg:54.40ms +step:9600/20000 train_loss:3.1341 train_time:522226ms step_avg:54.40ms +step:9600/20000 val_loss:3.1704 val_bpb:1.2274 train_time:522247ms step_avg:54.40ms +step:9700/20000 train_loss:3.3315 train_time:527630ms step_avg:54.39ms +step:9800/20000 train_loss:3.1174 train_time:533092ms step_avg:54.40ms +step:9800/20000 val_loss:3.1705 val_bpb:1.2274 train_time:533114ms step_avg:54.40ms +step:9900/20000 train_loss:3.1768 train_time:538498ms step_avg:54.39ms +step:10000/20000 train_loss:3.1855 train_time:543973ms step_avg:54.40ms +step:10000/20000 val_loss:3.1611 val_bpb:1.2238 train_time:543994ms step_avg:54.40ms +step:10100/20000 train_loss:3.1713 train_time:549445ms step_avg:54.40ms +step:10200/20000 train_loss:3.2169 train_time:554948ms step_avg:54.41ms +step:10200/20000 val_loss:3.1503 val_bpb:1.2196 train_time:554969ms step_avg:54.41ms +step:10300/20000 train_loss:3.3841 train_time:560420ms step_avg:54.41ms +step:10400/20000 train_loss:3.1940 train_time:565821ms step_avg:54.41ms +step:10400/20000 val_loss:3.1376 val_bpb:1.2147 train_time:565843ms step_avg:54.41ms +step:10500/20000 train_loss:3.3124 train_time:571286ms step_avg:54.41ms +step:10600/20000 train_loss:3.1063 train_time:576682ms step_avg:54.40ms +step:10600/20000 val_loss:3.1247 val_bpb:1.2097 train_time:576704ms step_avg:54.41ms +step:10700/20000 train_loss:3.1608 train_time:582151ms step_avg:54.41ms +step:10800/20000 train_loss:3.2238 train_time:587552ms step_avg:54.40ms +step:10800/20000 val_loss:3.1118 val_bpb:1.2047 train_time:587573ms step_avg:54.40ms +step:10900/20000 train_loss:3.1988 train_time:593008ms step_avg:54.40ms +step:11000/20000 train_loss:3.1866 train_time:598406ms step_avg:54.40ms +step:11000/20000 val_loss:3.1028 val_bpb:1.2012 train_time:598427ms step_avg:54.40ms +step:11030/20000 val_loss:3.1023 val_bpb:1.2010 train_time:600046ms step_avg:54.40ms +stopping_early: wallclock_cap train_time:600046ms step:11030/20000 +peak memory allocated: 10490 MiB reserved: 11622 MiB +Serialized model: 58152343 bytes +Code size: 73221 bytes +Total submission size: 58225564 bytes +Serialized model int8+zlib: 15340947 bytes (payload:16483504 raw_torch:16528409 payload_ratio:3.53x) +Total submission size int8+zlib: 15414168 bytes +final_int8_zlib_roundtrip val_loss:3.1218 val_bpb:1.2085 eval_time:1091ms +final_int8_zlib_roundtrip_exact val_loss:3.12179569 val_bpb:1.20854439 +ttt:start chunks=1238 chunk_tokens=32768 windows=39590 ttt_lr=0.005 ttt_epochs=3 freeze_blocks=0 +ttt:params unfrozen=16364616 frozen=0 + ttt_chunk [1/1238] bpb=1.210819 time=0.3s + ttt_chunk [11/1238] bpb=1.183445 time=1.4s + ttt_chunk [21/1238] bpb=1.233162 time=2.5s + ttt_chunk [31/1238] bpb=1.230954 time=3.6s + ttt_chunk [41/1238] bpb=1.224879 time=4.7s + ttt_chunk [51/1238] bpb=1.219433 time=5.9s + ttt_chunk [61/1238] bpb=1.212036 time=7.0s + ttt_chunk [71/1238] bpb=1.217920 time=8.2s + ttt_chunk [81/1238] bpb=1.212742 time=9.3s + ttt_chunk [91/1238] bpb=1.208864 time=10.4s + ttt_chunk [101/1238] bpb=1.208079 time=11.4s + ttt_chunk [111/1238] bpb=1.206968 time=12.5s + ttt_chunk [121/1238] bpb=1.209904 time=13.6s + ttt_chunk [131/1238] bpb=1.213614 time=14.7s + ttt_chunk [141/1238] bpb=1.213821 time=15.8s + ttt_chunk [151/1238] bpb=1.213630 time=16.9s + ttt_chunk [161/1238] bpb=1.214049 time=18.0s + ttt_chunk [171/1238] bpb=1.213960 time=19.1s + ttt_chunk [181/1238] bpb=1.212196 time=20.2s + ttt_chunk [191/1238] bpb=1.211686 time=21.3s + ttt_chunk [201/1238] bpb=1.208935 time=22.4s + ttt_chunk [211/1238] bpb=1.213224 time=23.5s + ttt_chunk [221/1238] bpb=1.212845 time=24.6s + ttt_chunk [231/1238] bpb=1.214621 time=25.8s + ttt_chunk [241/1238] bpb=1.214354 time=26.9s + ttt_chunk [251/1238] bpb=1.214464 time=28.0s + ttt_chunk [261/1238] bpb=1.214695 time=29.2s + ttt_chunk [271/1238] bpb=1.214841 time=30.3s + ttt_chunk [281/1238] bpb=1.213948 time=31.4s + ttt_chunk [291/1238] bpb=1.214915 time=32.5s + ttt_chunk [301/1238] bpb=1.214754 time=33.6s + ttt_chunk [311/1238] bpb=1.213150 time=34.8s + ttt_chunk [321/1238] bpb=1.213015 time=35.8s + ttt_chunk [331/1238] bpb=1.213417 time=36.9s + ttt_chunk [341/1238] bpb=1.212583 time=38.1s + ttt_chunk [351/1238] bpb=1.213327 time=39.2s + ttt_chunk [361/1238] bpb=1.212198 time=40.3s + ttt_chunk [371/1238] bpb=1.210693 time=41.4s + ttt_chunk [381/1238] bpb=1.210881 time=42.5s + ttt_chunk [391/1238] bpb=1.210383 time=43.7s + ttt_chunk [401/1238] bpb=1.209970 time=44.7s + ttt_chunk [411/1238] bpb=1.210587 time=45.8s + ttt_chunk [421/1238] bpb=1.209755 time=47.0s + ttt_chunk [431/1238] bpb=1.209838 time=48.1s + ttt_chunk [441/1238] bpb=1.210029 time=49.2s + ttt_chunk [451/1238] bpb=1.211321 time=50.4s + ttt_chunk [461/1238] bpb=1.209507 time=51.5s + ttt_chunk [471/1238] bpb=1.209426 time=52.6s + ttt_chunk [481/1238] bpb=1.209486 time=53.6s + ttt_chunk [491/1238] bpb=1.209829 time=54.8s + ttt_chunk [501/1238] bpb=1.209298 time=55.9s + ttt_chunk [511/1238] bpb=1.209098 time=57.0s + ttt_chunk [521/1238] bpb=1.208703 time=58.1s + ttt_chunk [531/1238] bpb=1.208627 time=59.2s + ttt_chunk [541/1238] bpb=1.208709 time=60.3s + ttt_chunk [551/1238] bpb=1.208345 time=61.5s + ttt_chunk [561/1238] bpb=1.207947 time=62.6s + ttt_chunk [571/1238] bpb=1.207253 time=63.7s + ttt_chunk [581/1238] bpb=1.207468 time=64.8s + ttt_chunk [591/1238] bpb=1.207642 time=65.9s + ttt_chunk [601/1238] bpb=1.207621 time=67.1s + ttt_chunk [611/1238] bpb=1.208172 time=68.2s + ttt_chunk [621/1238] bpb=1.209051 time=69.3s + ttt_chunk [631/1238] bpb=1.208977 time=70.4s + ttt_chunk [641/1238] bpb=1.209297 time=71.5s + ttt_chunk [651/1238] bpb=1.209591 time=72.6s + ttt_chunk [661/1238] bpb=1.208884 time=73.7s + ttt_chunk [671/1238] bpb=1.208413 time=74.8s + ttt_chunk [681/1238] bpb=1.209755 time=75.9s + ttt_chunk [691/1238] bpb=1.209821 time=77.1s + ttt_chunk [701/1238] bpb=1.209553 time=78.2s + ttt_chunk [711/1238] bpb=1.210331 time=79.3s + ttt_chunk [721/1238] bpb=1.210629 time=80.4s + ttt_chunk [731/1238] bpb=1.209954 time=81.5s + ttt_chunk [741/1238] bpb=1.209835 time=82.6s + ttt_chunk [751/1238] bpb=1.208895 time=83.7s + ttt_chunk [761/1238] bpb=1.207971 time=84.8s + ttt_chunk [771/1238] bpb=1.206830 time=85.9s + ttt_chunk [781/1238] bpb=1.206668 time=87.1s + ttt_chunk [791/1238] bpb=1.207100 time=88.2s + ttt_chunk [801/1238] bpb=1.207544 time=89.3s + ttt_chunk [811/1238] bpb=1.206938 time=90.4s + ttt_chunk [821/1238] bpb=1.205941 time=91.5s + ttt_chunk [831/1238] bpb=1.205660 time=92.6s + ttt_chunk [841/1238] bpb=1.205264 time=93.7s + ttt_chunk [851/1238] bpb=1.204758 time=94.8s + ttt_chunk [861/1238] bpb=1.204335 time=95.9s + ttt_chunk [871/1238] bpb=1.204032 time=97.0s + ttt_chunk [881/1238] bpb=1.203622 time=98.1s + ttt_chunk [891/1238] bpb=1.203002 time=99.2s + ttt_chunk [901/1238] bpb=1.203450 time=100.3s + ttt_chunk [911/1238] bpb=1.203271 time=101.4s + ttt_chunk [921/1238] bpb=1.203639 time=102.5s + ttt_chunk [931/1238] bpb=1.204414 time=103.7s + ttt_chunk [941/1238] bpb=1.204878 time=104.8s + ttt_chunk [951/1238] bpb=1.204918 time=105.9s + ttt_chunk [961/1238] bpb=1.205754 time=107.0s + ttt_chunk [971/1238] bpb=1.206309 time=108.1s + ttt_chunk [981/1238] bpb=1.206714 time=109.2s + ttt_chunk [991/1238] bpb=1.206541 time=110.3s + ttt_chunk [1001/1238] bpb=1.206827 time=111.3s + ttt_chunk [1011/1238] bpb=1.207203 time=112.5s + ttt_chunk [1021/1238] bpb=1.207889 time=113.6s + ttt_chunk [1031/1238] bpb=1.208448 time=114.7s + ttt_chunk [1041/1238] bpb=1.208842 time=115.8s + ttt_chunk [1051/1238] bpb=1.208656 time=116.9s + ttt_chunk [1061/1238] bpb=1.208743 time=118.0s + ttt_chunk [1071/1238] bpb=1.208752 time=119.1s + ttt_chunk [1081/1238] bpb=1.208575 time=120.1s + ttt_chunk [1091/1238] bpb=1.208638 time=121.3s + ttt_chunk [1101/1238] bpb=1.209126 time=122.4s + ttt_chunk [1111/1238] bpb=1.209403 time=123.5s + ttt_chunk [1121/1238] bpb=1.209539 time=124.6s + ttt_chunk [1131/1238] bpb=1.209098 time=125.7s + ttt_chunk [1141/1238] bpb=1.208701 time=126.8s + ttt_chunk [1151/1238] bpb=1.208727 time=127.9s + ttt_chunk [1161/1238] bpb=1.208870 time=129.0s + ttt_chunk [1171/1238] bpb=1.208562 time=130.2s + ttt_chunk [1181/1238] bpb=1.208056 time=131.3s + ttt_chunk [1191/1238] bpb=1.208166 time=132.4s + ttt_chunk [1201/1238] bpb=1.208296 time=133.5s + ttt_chunk [1211/1238] bpb=1.208066 time=134.6s + ttt_chunk [1221/1238] bpb=1.207489 time=135.7s + ttt_chunk [1231/1238] bpb=1.207166 time=136.8s + ttt_chunk [1238/1238] bpb=1.207117 time=137.4s +ttt:done val_loss=3.116912 val_bpb=1.206654 elapsed=137.5s +final_int8_ttt val_loss:3.1169 val_bpb:1.2067 eval_time:137492ms +final_int8_ttt_exact val_loss:3.11691247 val_bpb:1.20665395 diff --git a/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed2025.log b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed2025.log new file mode 100644 index 0000000000..1cde783c28 --- /dev/null +++ b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed2025.log @@ -0,0 +1,1929 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + # Gated Attention (https://arxiv.org/abs/2505.06708, NeurIPS 2025 Best Paper) + # "none" = disabled (baseline), "headwise" = 1 gate per head, "elementwise" = 1 gate per dim + gated_attn = os.environ.get("GATED_ATTN", "none") + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Activation: "relu2" (default), "leaky_relu2" = LeakyReLU(0.5)² (PG ranks 10-11) + activation = os.environ.get("ACTIVATION", "relu2") + + # Test-Time Training (Score-First TTT, PG ranks 1-3) + # "none" = disabled (baseline), "score_first" = legal score-before-update TTT + ttt_mode = os.environ.get("TTT_MODE", "none") + ttt_lr = float(os.environ.get("TTT_LR", "0.005")) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", "0.9")) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "3")) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", "32768")) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", "0")) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", "1.0")) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", "32")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# SCORE-FIRST TTT (Legal Test-Time Training) +# ----------------------------- +# Protocol (PR #461 recipe, PG ranks 1-3): +# 1. Chunk validation data into ttt_chunk_tokens windows +# 2. For each chunk: SCORE all tokens under inference_mode(), then TRAIN on scored data +# 3. Last chunk: score only (never train on un-scored data) +# Guarantees: every token scored BEFORE any parameter update that could use it. + +def eval_val_score_first_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + stride = seq_len # non-overlapping windows for simplicity + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + + # Assign windows to chunks + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + ci = min(ws // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"windows={len(window_starts)} ttt_lr={args.ttt_lr} " + f"ttt_epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = any(f"blocks.{bi}." in name for bi in frozen_block_ids) + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- PHASE 1: SCORE (no gradients) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), args.ttt_batch_seqs): + batch_ws = my_windows[bi:bi + args.ttt_batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i in range(bsz): + wlen = wlens[i] + scored_nll = nll[i, :wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen) + tgt = y_batch[i, :wlen] + prev = x_batch[i, :wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- PHASE 2: TRAIN on scored chunk (skip last chunk) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR decay across chunks + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + + for _ep in range(args.ttt_epochs): + for bs in range(my_seq_s, my_seq_e, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_seq_e) + start_tok = chunk_start + bs * seq_len + end_tok = chunk_start + be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + # AllReduce across ranks + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore grad for all params + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +# ╔═══════════════════════════════════════════════════════════════════════════════════╗ +# ║ CAUSAL SELF-ATTENTION (with GQA) ║ +# ║ ║ +# ║ Plain English: Each token asks "which earlier tokens should I pay attention to?" ║ +# ║ It computes a compatibility score (Q·K) with every previous token, then takes ║ +# ║ a weighted average of their values (V). "Causal" = can only look backward. ║ +# ║ ║ +# ║ ┌──────────────── Input x: [batch, seq_len, 512] ────────────────┐ ║ +# ║ │ │ │ ║ +# ║ │ ┌──────────────────┼──────────────────┐ │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ ▼ ▼ ▼ │ ║ +# ║ │ c_q (512→512) c_k (512→256) c_v (512→256) │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ ▼ ▼ ▼ │ ║ +# ║ │ Q: 8 heads × 64d K: 4 heads × 64d V: 4 heads × 64d │ ║ +# ║ │ │ │ │ ║ +# ║ │ ▼ ▼ │ ║ +# ║ │ RMSNorm(Q) RMSNorm(K) ← stabilize magnitudes │ ║ +# ║ │ │ │ │ ║ +# ║ │ ▼ ▼ │ ║ +# ║ │ RoPE(Q) RoPE(K) ← encode positions │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ Q × q_gain ← learnable sharpness │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ └──────────────────┼──────────────────┘ │ ║ +# ║ │ ▼ │ ║ +# ║ │ scaled_dot_product_attention │ ║ +# ║ │ softmax(Q·Kᵀ / √64) · V │ ║ +# ║ │ (causal mask + FlashAttention + GQA) │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ proj (512→512) │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ Output: [batch, seq_len, 512] │ ║ +# ║ └─────────────────────────────────────────────────────────────────┘ ║ +# ║ ║ +# ║ GROUPED QUERY ATTENTION (GQA) — why K,V have 4 heads but Q has 8: ║ +# ║ ║ +# ║ Standard (MHA, 8×8): This model (GQA, 8×4): ║ +# ║ Q₀ → K₀ Q₄ → K₄ Q₀ ─┐ ║ +# ║ Q₁ → K₁ Q₅ → K₅ Q₁ ─┤→ K₀,V₀ ║ +# ║ Q₂ → K₂ Q₆ → K₆ Q₂ ─┐ ║ +# ║ Q₃ → K₃ Q₇ → K₇ Q₃ ─┤→ K₁,V₁ ║ +# ║ (8 KV pairs = full cost) Q₄ ─┐ ║ +# ║ Q₅ ─┤→ K₂,V₂ ║ +# ║ Q₆ ─┐ ║ +# ║ Q₇ ─┤→ K₃,V₃ ║ +# ║ (4 KV pairs = 50% memory) ║ +# ║ ║ +# ║ CAUSAL MASK (is_causal=True) — what makes this a language model: ║ +# ║ ║ +# ║ attends to→ t₀ t₁ t₂ t₃ t₄ ║ +# ║ t₀ ✓ ✗ ✗ ✗ ✗ "The" can only see itself ║ +# ║ t₁ ✓ ✓ ✗ ✗ ✗ "cat" sees "The cat" ║ +# ║ t₂ ✓ ✓ ✓ ✗ ✗ "sat" sees "The cat sat" ║ +# ║ t₃ ✓ ✓ ✓ ✓ ✗ "on" sees "The cat sat on" ║ +# ║ t₄ ✓ ✓ ✓ ✓ ✓ "the" sees everything before it ║ +# ║ ║ +# ║ Each token predicts the NEXT token without peeking at the answer. ║ +# ║ ║ +# ║ WHY EACH STEP EXISTS: ║ +# ║ • RMSNorm on Q,K → prevents attention scores from exploding to ±inf ║ +# ║ • RoPE on Q,K → encodes WHERE each token is (position 0, 1, 2...) ║ +# ║ • q_gain (4.0-5.25) → learnable per-head sharpness (higher = more focused) ║ +# ║ • √64 denominator → scales dot products so softmax doesn't saturate ║ +# ║ • proj → mixes info from all 8 heads back to 512 dimensions ║ +# ╚═══════════════════════════════════════════════════════════════════════════════════╝ +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads # 512 / 8 = 64 dimensions per head + self.gated_attn = gated_attn + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim # 4 × 64 = 256 (half of Q's 512, GQA savings) + # Gated Attention (arxiv.org/abs/2505.06708, NeurIPS 2025 Best Paper): + # Widen Q projection to also produce gate logits from the same input. + # "none" → c_q: 512→512, no gate + # "headwise" → c_q: 512→520, +8 dims (1 gate scalar per head, ~9K extra params/layer) + # "elementwise" → c_q: 512→1024, +512 dims (1 gate per dimension, doubles Q projection) + if gated_attn == "headwise": + self.gate_dim = num_heads # 8 extra outputs + elif gated_attn == "elementwise": + self.gate_dim = dim # 512 extra outputs + else: + self.gate_dim = 0 + self.c_q = CastedLinear(dim, dim + self.gate_dim, bias=False) # Query + gate logits + self.c_k = CastedLinear(dim, kv_dim, bias=False) # Key: 512→256 (only 4 KV heads) + self.c_v = CastedLinear(dim, kv_dim, bias=False) # Value: 512→256 (only 4 KV heads) + self.proj = CastedLinear(dim, dim, bias=False) # Output projection: 512→512 + self.proj._zero_init = True # Start at zero so skip connections dominate early + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + # Step 1: Project input → Q (+ gate logits if gated attention is enabled) + q_out = self.c_q(x) + if self.gate_dim > 0: + q_raw, gate_logits = q_out.split([dim, self.gate_dim], dim=-1) + else: + q_raw = q_out + # Reshape into multi-head format: [batch, heads, seq, head_dim] + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Step 2: Normalize Q and K so dot products don't explode + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + # Step 3: Apply rotary position embeddings — encode WHERE each token sits + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + # Step 4: Scale Q by learnable per-head gain — controls attention sharpness + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Step 5: The actual attention — softmax(Q·Kᵀ/√d) · V with causal mask + # Uses FlashAttention kernel under the hood (memory-efficient, no full NxN matrix) + # GQA: repeat K,V heads to match Q heads (e.g. 4 KV → 8 Q, repeat 2×) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + ) + # Step 6 (Gated Attention): sigmoid gate AFTER attention, BEFORE output projection + # Gate logits come from the Q projection (query-dependent, input-dependent). + # sigmoid(gate) ∈ [0,1] lets the model suppress uninformative heads/dims per token. + # Headwise: [bsz, seq, 8] → [bsz, 8, seq, 1] (one scalar per head) + # Elementwise: [bsz, seq, 512] → [bsz, 8, seq, 64] (one per dimension) + if self.gated_attn == "headwise": + gate = torch.sigmoid(gate_logits) # [bsz, seqlen, num_heads] + gate = gate.transpose(1, 2).unsqueeze(-1) # [bsz, num_heads, seqlen, 1] + y = y * gate + elif self.gated_attn == "elementwise": + gate = torch.sigmoid(gate_logits) # [bsz, seqlen, dim] + gate = gate.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + y = y * gate + # Step 7: Reshape from multi-head back to [batch, seq, 512] and project + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, activation: str = "relu2"): + super().__init__() + self.activation = activation + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = self.fc(x) + if self.activation == "leaky_relu2": + x = F.leaky_relu(x, negative_slope=0.5) + else: + x = torch.relu(x) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + activation: str = "relu2", + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, gated_attn) + self.mlp = MLP(dim, mlp_mult, activation) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + activation: str = "relu2", + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + gated_attn, + activation, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _run_backbone(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits [batch, seq, vocab] without computing loss.""" + return self._run_backbone(input_ids) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self._run_backbone(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + gated_attn=args.gated_attn, + activation=args.activation, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # ===================================================================== + # OPTIMIZER SPLIT — Different param types get different optimizers & LRs + # ===================================================================== + # + # The model's parameters are sorted into 3-4 groups, each with its own + # optimizer and learning rate. Think of it like a team where each role + # needs a different management style: + # + # ┌─────────────────────────────────────────────────────────────────┐ + # │ Parameter Group Optimizer LR % of Params │ + # │ ───────────────────── ───────── ────── ─────────── │ + # │ 1. Embedding table Adam 0.05 ~3% (SP1024) │ + # │ (tok_emb.weight) ~12% (SP8192) │ + # │ "the dictionary" │ + # │ │ + # │ 2. Matrix weights MUON 0.04 ~95% of blocks │ + # │ (Q, K, V, proj, The heavy lifters │ + # │ fc, mlp.proj) — 2D matrices │ + # │ "attention + MLP that do the real │ + # │ computation" learning │ + # │ │ + # │ 3. Scalar params Adam 0.04 <1% │ + # │ (attn_scale, Tiny knobs that │ + # │ mlp_scale, fine-tune how │ + # │ resid_mix, blocks blend │ + # │ q_gain, their outputs │ + # │ skip_weights) │ + # │ │ + # │ 4. Output head Adam 0.008 Only if untied │ + # │ (lm_head.weight) embeddings │ + # │ "vector → token (not used in │ + # │ probability" baseline) │ + # └─────────────────────────────────────────────────────────────────┘ + # + # WHY MUON FOR MATRICES? + # Muon orthogonalizes gradient updates via Newton-Schulz iterations. + # This only works on 2D matrices (needs rows & columns to orthogonalize). + # 1D vectors and scalars are too small — they use standard Adam instead. + # + # WHY DIFFERENT LEARNING RATES? + # - Embeddings are a lookup table — aggressive updates cause instability + # - Matrix weights are the core compute — Muon handles the LR scaling + # - Scalars are sensitive knobs — moderate LR with Adam's adaptive step + # - Output head (if untied) maps back to vocab — needs gentler updates + # + block_named_params = list(base_model.blocks.named_parameters()) + + # --- Group 2: Matrix params (2D tensors in transformer blocks) → Muon --- + # Collects: c_q.weight, c_k.weight, c_v.weight, proj.weight, fc.weight, mlp.proj.weight + # Excludes: control tensors like attn_scale, resid_mix (those are 2D but act as scalars) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + + # --- Group 3: Scalar params (1D vectors + control tensors) → Adam --- + # Collects: attn_scale, mlp_scale, resid_mix, q_gain, skip_weights + # These are the small "knobs" that tune how blocks combine their outputs + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) # U-Net skip connection weights + + # --- Group 1: Embedding table → Adam (gentle LR) --- + # The "dictionary" that maps token IDs to 512-dim vectors + # If tied: same weight used for input embedding AND output prediction (LR=0.05) + # If untied: separate input embedding (LR=0.6) and output head (LR=0.008) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, # fused=True: single GPU kernel for speed + ) + + # --- Group 2 optimizer: Muon for the heavy 2D matrices --- + # Uses Newton-Schulz orthogonalization (5 steps) to normalize gradients + # before applying them. This is the key innovation from modded-nanogpt. + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr # Stash base LR for warmdown scheduling + + # --- Group 3 optimizer: Adam for small scalar/vector params --- + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + + # Collect all optimizers — training loop will step() all of them each iteration + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + + # --- Group 4 (optional): Output head → Adam (only when embeddings are untied) --- + # In baseline, tie_embeddings=True so lm_head is None and this is skipped. + # When untied, the output head gets its own gentle LR (0.008) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Score-First TTT on quantized model (if enabled) + if args.ttt_mode == "score_first": + torch.cuda.synchronize() + t_ttt = time.perf_counter() + # Save quantized weights, TTT will modify them during adaptation + saved_state = copy.deepcopy(base_model.state_dict()) + ttt_loss, ttt_bpb = eval_val_score_first_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + log0=log0, + ) + # Restore quantized weights (TTT modifies them in-place) + base_model.load_state_dict(saved_state) + torch.cuda.synchronize() + log0(f"final_int8_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"final_int8_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Fri Apr 24 00:30:25 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 39C P0 117W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 39C P0 123W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 40C P0 117W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 35C P0 117W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 33C P0 115W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_8192_bpe.model +train_loader:dataset:fineweb10B_sp8192 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin tokens:40540160 +model_params:16364616 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2025 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:9.0076 val_bpb:3.4871 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:9.0081 train_time:38ms step_avg:38.05ms +step:2/20000 train_loss:15.4130 train_time:102ms step_avg:50.78ms +step:3/20000 train_loss:11.1693 train_time:149ms step_avg:49.60ms +step:4/20000 train_loss:8.9961 train_time:203ms step_avg:50.82ms +step:5/20000 train_loss:8.5988 train_time:257ms step_avg:51.43ms +step:6/20000 train_loss:8.4958 train_time:313ms step_avg:52.20ms +step:7/20000 train_loss:8.2557 train_time:366ms step_avg:52.30ms +step:8/20000 train_loss:8.0644 train_time:421ms step_avg:52.60ms +step:9/20000 train_loss:7.7044 train_time:476ms step_avg:52.84ms +step:10/20000 train_loss:7.5149 train_time:530ms step_avg:52.95ms +step:100/20000 train_loss:4.5776 train_time:5424ms step_avg:54.24ms +step:200/20000 train_loss:3.9471 train_time:10909ms step_avg:54.55ms +step:200/20000 val_loss:3.9730 val_bpb:1.5381 train_time:10929ms step_avg:54.64ms +step:300/20000 train_loss:3.8263 train_time:16304ms step_avg:54.35ms +step:400/20000 train_loss:3.6718 train_time:21765ms step_avg:54.41ms +step:400/20000 val_loss:3.6694 val_bpb:1.4205 train_time:21785ms step_avg:54.46ms +step:500/20000 train_loss:3.6175 train_time:27164ms step_avg:54.33ms +step:600/20000 train_loss:3.5503 train_time:32620ms step_avg:54.37ms +step:600/20000 val_loss:3.5567 val_bpb:1.3769 train_time:32641ms step_avg:54.40ms +step:700/20000 train_loss:3.3872 train_time:38017ms step_avg:54.31ms +step:800/20000 train_loss:3.4555 train_time:43482ms step_avg:54.35ms +step:800/20000 val_loss:3.4753 val_bpb:1.3454 train_time:43503ms step_avg:54.38ms +step:900/20000 train_loss:3.5995 train_time:48887ms step_avg:54.32ms +step:1000/20000 train_loss:3.1899 train_time:54352ms step_avg:54.35ms +step:1000/20000 val_loss:3.4285 val_bpb:1.3273 train_time:54373ms step_avg:54.37ms +step:1100/20000 train_loss:3.4276 train_time:59766ms step_avg:54.33ms +step:1200/20000 train_loss:3.4136 train_time:65242ms step_avg:54.37ms +step:1200/20000 val_loss:3.3930 val_bpb:1.3135 train_time:65263ms step_avg:54.39ms +step:1300/20000 train_loss:3.4069 train_time:70656ms step_avg:54.35ms +step:1400/20000 train_loss:3.4893 train_time:76124ms step_avg:54.37ms +step:1400/20000 val_loss:3.3705 val_bpb:1.3048 train_time:76145ms step_avg:54.39ms +step:1500/20000 train_loss:3.5440 train_time:81540ms step_avg:54.36ms +step:1600/20000 train_loss:3.5639 train_time:87009ms step_avg:54.38ms +step:1600/20000 val_loss:3.3497 val_bpb:1.2968 train_time:87030ms step_avg:54.39ms +step:1700/20000 train_loss:3.4667 train_time:92421ms step_avg:54.37ms +step:1800/20000 train_loss:3.2854 train_time:97890ms step_avg:54.38ms +step:1800/20000 val_loss:3.3334 val_bpb:1.2905 train_time:97911ms step_avg:54.39ms +step:1900/20000 train_loss:3.4558 train_time:103301ms step_avg:54.37ms +step:2000/20000 train_loss:3.3000 train_time:108771ms step_avg:54.39ms +step:2000/20000 val_loss:3.3192 val_bpb:1.2850 train_time:108792ms step_avg:54.40ms +step:2100/20000 train_loss:3.1869 train_time:114245ms step_avg:54.40ms +step:2200/20000 train_loss:3.3284 train_time:119654ms step_avg:54.39ms +step:2200/20000 val_loss:3.3098 val_bpb:1.2813 train_time:119675ms step_avg:54.40ms +step:2300/20000 train_loss:3.4382 train_time:125126ms step_avg:54.40ms +step:2400/20000 train_loss:3.3044 train_time:130537ms step_avg:54.39ms +step:2400/20000 val_loss:3.2953 val_bpb:1.2757 train_time:130554ms step_avg:54.40ms +step:2500/20000 train_loss:3.1765 train_time:136005ms step_avg:54.40ms +step:2600/20000 train_loss:3.2905 train_time:141411ms step_avg:54.39ms +step:2600/20000 val_loss:3.2846 val_bpb:1.2716 train_time:141432ms step_avg:54.40ms +step:2700/20000 train_loss:3.4104 train_time:146879ms step_avg:54.40ms +step:2800/20000 train_loss:3.4607 train_time:152278ms step_avg:54.38ms +step:2800/20000 val_loss:3.2787 val_bpb:1.2693 train_time:152299ms step_avg:54.39ms +step:2900/20000 train_loss:3.2319 train_time:157737ms step_avg:54.39ms +step:3000/20000 train_loss:3.3817 train_time:163138ms step_avg:54.38ms +step:3000/20000 val_loss:3.2723 val_bpb:1.2668 train_time:163159ms step_avg:54.39ms +step:3100/20000 train_loss:3.1766 train_time:168601ms step_avg:54.39ms +step:3200/20000 train_loss:3.4878 train_time:174002ms step_avg:54.38ms +step:3200/20000 val_loss:3.2661 val_bpb:1.2644 train_time:174022ms step_avg:54.38ms +step:3300/20000 train_loss:3.2958 train_time:179460ms step_avg:54.38ms +step:3400/20000 train_loss:3.3955 train_time:184861ms step_avg:54.37ms +step:3400/20000 val_loss:3.2592 val_bpb:1.2617 train_time:184882ms step_avg:54.38ms +step:3500/20000 train_loss:3.5916 train_time:190328ms step_avg:54.38ms +step:3600/20000 train_loss:3.2023 train_time:195723ms step_avg:54.37ms +step:3600/20000 val_loss:3.2526 val_bpb:1.2592 train_time:195744ms step_avg:54.37ms +step:3700/20000 train_loss:3.2239 train_time:201185ms step_avg:54.37ms +step:3800/20000 train_loss:3.1998 train_time:206582ms step_avg:54.36ms +step:3800/20000 val_loss:3.2488 val_bpb:1.2577 train_time:206603ms step_avg:54.37ms +step:3900/20000 train_loss:3.2075 train_time:212044ms step_avg:54.37ms +step:4000/20000 train_loss:3.3165 train_time:217441ms step_avg:54.36ms +step:4000/20000 val_loss:3.2424 val_bpb:1.2552 train_time:217462ms step_avg:54.37ms +step:4100/20000 train_loss:2.6903 train_time:222903ms step_avg:54.37ms +step:4200/20000 train_loss:3.1707 train_time:228369ms step_avg:54.37ms +step:4200/20000 val_loss:3.2362 val_bpb:1.2528 train_time:228390ms step_avg:54.38ms +step:4300/20000 train_loss:3.2680 train_time:233775ms step_avg:54.37ms +step:4400/20000 train_loss:3.1482 train_time:239232ms step_avg:54.37ms +step:4400/20000 val_loss:3.2348 val_bpb:1.2523 train_time:239253ms step_avg:54.38ms +step:4500/20000 train_loss:3.2017 train_time:244634ms step_avg:54.36ms +step:4600/20000 train_loss:3.1699 train_time:250101ms step_avg:54.37ms +step:4600/20000 val_loss:3.2287 val_bpb:1.2499 train_time:250122ms step_avg:54.37ms +step:4700/20000 train_loss:3.1495 train_time:255503ms step_avg:54.36ms +step:4800/20000 train_loss:3.2844 train_time:260963ms step_avg:54.37ms +step:4800/20000 val_loss:3.2241 val_bpb:1.2482 train_time:260983ms step_avg:54.37ms +step:4900/20000 train_loss:3.1245 train_time:266364ms step_avg:54.36ms +step:5000/20000 train_loss:3.3008 train_time:271828ms step_avg:54.37ms +step:5000/20000 val_loss:3.2213 val_bpb:1.2471 train_time:271848ms step_avg:54.37ms +step:5100/20000 train_loss:3.3342 train_time:277227ms step_avg:54.36ms +step:5200/20000 train_loss:3.2957 train_time:282687ms step_avg:54.36ms +step:5200/20000 val_loss:3.2189 val_bpb:1.2461 train_time:282708ms step_avg:54.37ms +step:5300/20000 train_loss:3.2338 train_time:288091ms step_avg:54.36ms +step:5400/20000 train_loss:3.1983 train_time:293556ms step_avg:54.36ms +step:5400/20000 val_loss:3.2145 val_bpb:1.2444 train_time:293577ms step_avg:54.37ms +step:5500/20000 train_loss:3.2069 train_time:298964ms step_avg:54.36ms +step:5600/20000 train_loss:3.3033 train_time:304423ms step_avg:54.36ms +step:5600/20000 val_loss:3.2131 val_bpb:1.2439 train_time:304443ms step_avg:54.36ms +step:5700/20000 train_loss:3.0399 train_time:309827ms step_avg:54.36ms +step:5800/20000 train_loss:3.2288 train_time:315296ms step_avg:54.36ms +step:5800/20000 val_loss:3.2098 val_bpb:1.2426 train_time:315317ms step_avg:54.36ms +step:5900/20000 train_loss:3.2138 train_time:320703ms step_avg:54.36ms +step:6000/20000 train_loss:3.3416 train_time:326166ms step_avg:54.36ms +step:6000/20000 val_loss:3.2080 val_bpb:1.2419 train_time:326186ms step_avg:54.36ms +step:6100/20000 train_loss:3.2299 train_time:331573ms step_avg:54.36ms +step:6200/20000 train_loss:3.3792 train_time:337040ms step_avg:54.36ms +step:6200/20000 val_loss:3.2075 val_bpb:1.2417 train_time:337061ms step_avg:54.36ms +step:6300/20000 train_loss:3.2788 train_time:342502ms step_avg:54.37ms +step:6400/20000 train_loss:3.2851 train_time:347907ms step_avg:54.36ms +step:6400/20000 val_loss:3.2076 val_bpb:1.2418 train_time:347927ms step_avg:54.36ms +step:6500/20000 train_loss:3.2235 train_time:353376ms step_avg:54.37ms +step:6600/20000 train_loss:3.2547 train_time:358779ms step_avg:54.36ms +step:6600/20000 val_loss:3.2014 val_bpb:1.2394 train_time:358800ms step_avg:54.36ms +step:6700/20000 train_loss:2.9329 train_time:364244ms step_avg:54.36ms +step:6800/20000 train_loss:3.1305 train_time:369652ms step_avg:54.36ms +step:6800/20000 val_loss:3.2004 val_bpb:1.2390 train_time:369672ms step_avg:54.36ms +step:6900/20000 train_loss:3.2880 train_time:375123ms step_avg:54.37ms +step:7000/20000 train_loss:3.1735 train_time:380529ms step_avg:54.36ms +step:7000/20000 val_loss:3.1973 val_bpb:1.2378 train_time:380550ms step_avg:54.36ms +step:7100/20000 train_loss:3.1964 train_time:385998ms step_avg:54.37ms +step:7200/20000 train_loss:3.0903 train_time:391402ms step_avg:54.36ms +step:7200/20000 val_loss:3.1920 val_bpb:1.2357 train_time:391423ms step_avg:54.36ms +step:7300/20000 train_loss:2.9739 train_time:396870ms step_avg:54.37ms +step:7400/20000 train_loss:3.1673 train_time:402277ms step_avg:54.36ms +step:7400/20000 val_loss:3.1938 val_bpb:1.2364 train_time:402298ms step_avg:54.36ms +step:7500/20000 train_loss:3.2342 train_time:407739ms step_avg:54.37ms +step:7600/20000 train_loss:3.3123 train_time:413143ms step_avg:54.36ms +step:7600/20000 val_loss:3.1885 val_bpb:1.2344 train_time:413163ms step_avg:54.36ms +step:7700/20000 train_loss:3.1705 train_time:418610ms step_avg:54.36ms +step:7800/20000 train_loss:3.3841 train_time:424014ms step_avg:54.36ms +step:7800/20000 val_loss:3.1882 val_bpb:1.2343 train_time:424035ms step_avg:54.36ms +step:7900/20000 train_loss:3.2268 train_time:429476ms step_avg:54.36ms +step:8000/20000 train_loss:3.1717 train_time:434881ms step_avg:54.36ms +step:8000/20000 val_loss:3.1859 val_bpb:1.2334 train_time:434901ms step_avg:54.36ms +step:8100/20000 train_loss:3.2339 train_time:440349ms step_avg:54.36ms +step:8200/20000 train_loss:3.1458 train_time:445754ms step_avg:54.36ms +step:8200/20000 val_loss:3.1850 val_bpb:1.2330 train_time:445774ms step_avg:54.36ms +step:8300/20000 train_loss:3.2185 train_time:451221ms step_avg:54.36ms +step:8400/20000 train_loss:3.1502 train_time:456684ms step_avg:54.37ms +step:8400/20000 val_loss:3.1914 val_bpb:1.2355 train_time:456705ms step_avg:54.37ms +step:8500/20000 train_loss:3.2065 train_time:462089ms step_avg:54.36ms +step:8600/20000 train_loss:3.1955 train_time:467550ms step_avg:54.37ms +step:8600/20000 val_loss:3.1797 val_bpb:1.2310 train_time:467571ms step_avg:54.37ms +step:8700/20000 train_loss:3.2285 train_time:472957ms step_avg:54.36ms +step:8800/20000 train_loss:3.3321 train_time:478423ms step_avg:54.37ms +step:8800/20000 val_loss:3.1801 val_bpb:1.2311 train_time:478444ms step_avg:54.37ms +step:8900/20000 train_loss:3.2777 train_time:483833ms step_avg:54.36ms +step:9000/20000 train_loss:3.1999 train_time:489295ms step_avg:54.37ms +step:9000/20000 val_loss:3.1775 val_bpb:1.2301 train_time:489315ms step_avg:54.37ms +step:9100/20000 train_loss:3.1205 train_time:494701ms step_avg:54.36ms +step:9200/20000 train_loss:3.1804 train_time:500279ms step_avg:54.38ms +step:9200/20000 val_loss:3.1832 val_bpb:1.2323 train_time:500300ms step_avg:54.38ms +step:9300/20000 train_loss:3.2254 train_time:505683ms step_avg:54.37ms +step:9400/20000 train_loss:3.1625 train_time:511146ms step_avg:54.38ms +step:9400/20000 val_loss:3.1738 val_bpb:1.2287 train_time:511165ms step_avg:54.38ms +step:9500/20000 train_loss:3.0975 train_time:516546ms step_avg:54.37ms +step:9600/20000 train_loss:3.1377 train_time:522014ms step_avg:54.38ms +step:9600/20000 val_loss:3.1718 val_bpb:1.2279 train_time:522034ms step_avg:54.38ms +step:9700/20000 train_loss:3.3282 train_time:527422ms step_avg:54.37ms +step:9800/20000 train_loss:3.1153 train_time:532884ms step_avg:54.38ms +step:9800/20000 val_loss:3.1735 val_bpb:1.2286 train_time:532905ms step_avg:54.38ms +step:9900/20000 train_loss:3.1701 train_time:538288ms step_avg:54.37ms +step:10000/20000 train_loss:3.1832 train_time:543755ms step_avg:54.38ms +step:10000/20000 val_loss:3.1640 val_bpb:1.2249 train_time:543776ms step_avg:54.38ms +step:10100/20000 train_loss:3.1720 train_time:549158ms step_avg:54.37ms +step:10200/20000 train_loss:3.2200 train_time:554620ms step_avg:54.37ms +step:10200/20000 val_loss:3.1523 val_bpb:1.2204 train_time:554641ms step_avg:54.38ms +step:10300/20000 train_loss:3.3841 train_time:560087ms step_avg:54.38ms +step:10400/20000 train_loss:3.1962 train_time:565486ms step_avg:54.37ms +step:10400/20000 val_loss:3.1406 val_bpb:1.2158 train_time:565507ms step_avg:54.38ms +step:10500/20000 train_loss:3.3115 train_time:570945ms step_avg:54.38ms +step:10600/20000 train_loss:3.1072 train_time:576346ms step_avg:54.37ms +step:10600/20000 val_loss:3.1274 val_bpb:1.2107 train_time:576366ms step_avg:54.37ms +step:10700/20000 train_loss:3.1606 train_time:581808ms step_avg:54.37ms +step:10800/20000 train_loss:3.2310 train_time:587206ms step_avg:54.37ms +step:10800/20000 val_loss:3.1142 val_bpb:1.2056 train_time:587227ms step_avg:54.37ms +step:10900/20000 train_loss:3.2008 train_time:592662ms step_avg:54.37ms +step:11000/20000 train_loss:3.1905 train_time:598057ms step_avg:54.37ms +step:11000/20000 val_loss:3.1049 val_bpb:1.2020 train_time:598078ms step_avg:54.37ms +step:11036/20000 val_loss:3.1043 val_bpb:1.2018 train_time:600020ms step_avg:54.37ms +stopping_early: wallclock_cap train_time:600020ms step:11036/20000 +peak memory allocated: 10490 MiB reserved: 11544 MiB +Serialized model: 58152343 bytes +Code size: 73221 bytes +Total submission size: 58225564 bytes +Serialized model int8+zlib: 15337072 bytes (payload:16483504 raw_torch:16528409 payload_ratio:3.53x) +Total submission size int8+zlib: 15410293 bytes +final_int8_zlib_roundtrip val_loss:3.1260 val_bpb:1.2102 eval_time:1090ms +final_int8_zlib_roundtrip_exact val_loss:3.12596029 val_bpb:1.21015664 +ttt:start chunks=1238 chunk_tokens=32768 windows=39590 ttt_lr=0.005 ttt_epochs=3 freeze_blocks=0 +ttt:params unfrozen=16364616 frozen=0 + ttt_chunk [1/1238] bpb=1.212951 time=0.3s + ttt_chunk [11/1238] bpb=1.187075 time=1.5s + ttt_chunk [21/1238] bpb=1.235597 time=2.6s + ttt_chunk [31/1238] bpb=1.232924 time=3.7s + ttt_chunk [41/1238] bpb=1.227025 time=4.9s + ttt_chunk [51/1238] bpb=1.221387 time=6.0s + ttt_chunk [61/1238] bpb=1.213670 time=7.1s + ttt_chunk [71/1238] bpb=1.219303 time=8.3s + ttt_chunk [81/1238] bpb=1.214049 time=9.4s + ttt_chunk [91/1238] bpb=1.210216 time=10.5s + ttt_chunk [101/1238] bpb=1.209118 time=11.6s + ttt_chunk [111/1238] bpb=1.208120 time=12.7s + ttt_chunk [121/1238] bpb=1.210948 time=13.8s + ttt_chunk [131/1238] bpb=1.214678 time=14.9s + ttt_chunk [141/1238] bpb=1.214882 time=16.0s + ttt_chunk [151/1238] bpb=1.214592 time=17.1s + ttt_chunk [161/1238] bpb=1.215128 time=18.2s + ttt_chunk [171/1238] bpb=1.214984 time=19.4s + ttt_chunk [181/1238] bpb=1.213126 time=20.5s + ttt_chunk [191/1238] bpb=1.212546 time=21.6s + ttt_chunk [201/1238] bpb=1.209728 time=22.7s + ttt_chunk [211/1238] bpb=1.214183 time=23.8s + ttt_chunk [221/1238] bpb=1.213949 time=24.9s + ttt_chunk [231/1238] bpb=1.215653 time=26.0s + ttt_chunk [241/1238] bpb=1.215334 time=27.1s + ttt_chunk [251/1238] bpb=1.215398 time=28.2s + ttt_chunk [261/1238] bpb=1.215660 time=29.3s + ttt_chunk [271/1238] bpb=1.215834 time=30.5s + ttt_chunk [281/1238] bpb=1.214855 time=31.6s + ttt_chunk [291/1238] bpb=1.215876 time=32.7s + ttt_chunk [301/1238] bpb=1.215702 time=33.9s + ttt_chunk [311/1238] bpb=1.214133 time=35.0s + ttt_chunk [321/1238] bpb=1.214030 time=36.1s + ttt_chunk [331/1238] bpb=1.214348 time=37.3s + ttt_chunk [341/1238] bpb=1.213510 time=38.4s + ttt_chunk [351/1238] bpb=1.214251 time=39.6s + ttt_chunk [361/1238] bpb=1.213138 time=40.7s + ttt_chunk [371/1238] bpb=1.211587 time=41.8s + ttt_chunk [381/1238] bpb=1.211790 time=43.0s + ttt_chunk [391/1238] bpb=1.211304 time=44.1s + ttt_chunk [401/1238] bpb=1.210925 time=45.2s + ttt_chunk [411/1238] bpb=1.211466 time=46.4s + ttt_chunk [421/1238] bpb=1.210757 time=47.5s + ttt_chunk [431/1238] bpb=1.210868 time=48.7s + ttt_chunk [441/1238] bpb=1.211092 time=49.8s + ttt_chunk [451/1238] bpb=1.212412 time=50.9s + ttt_chunk [461/1238] bpb=1.210609 time=52.0s + ttt_chunk [471/1238] bpb=1.210559 time=53.1s + ttt_chunk [481/1238] bpb=1.210645 time=54.3s + ttt_chunk [491/1238] bpb=1.210953 time=55.4s + ttt_chunk [501/1238] bpb=1.210407 time=56.6s + ttt_chunk [511/1238] bpb=1.210206 time=57.7s + ttt_chunk [521/1238] bpb=1.209849 time=58.8s + ttt_chunk [531/1238] bpb=1.209756 time=59.9s + ttt_chunk [541/1238] bpb=1.209854 time=61.1s + ttt_chunk [551/1238] bpb=1.209508 time=62.2s + ttt_chunk [561/1238] bpb=1.209086 time=63.3s + ttt_chunk [571/1238] bpb=1.208407 time=64.4s + ttt_chunk [581/1238] bpb=1.208641 time=65.6s + ttt_chunk [591/1238] bpb=1.208872 time=66.7s + ttt_chunk [601/1238] bpb=1.208806 time=67.8s + ttt_chunk [611/1238] bpb=1.209321 time=68.9s + ttt_chunk [621/1238] bpb=1.210181 time=70.0s + ttt_chunk [631/1238] bpb=1.210057 time=71.2s + ttt_chunk [641/1238] bpb=1.210362 time=72.3s + ttt_chunk [651/1238] bpb=1.210634 time=73.5s + ttt_chunk [661/1238] bpb=1.209919 time=74.6s + ttt_chunk [671/1238] bpb=1.209428 time=75.7s + ttt_chunk [681/1238] bpb=1.210755 time=76.9s + ttt_chunk [691/1238] bpb=1.210819 time=78.0s + ttt_chunk [701/1238] bpb=1.210541 time=79.1s + ttt_chunk [711/1238] bpb=1.211284 time=80.2s + ttt_chunk [721/1238] bpb=1.211625 time=81.4s + ttt_chunk [731/1238] bpb=1.210956 time=82.5s + ttt_chunk [741/1238] bpb=1.210797 time=83.6s + ttt_chunk [751/1238] bpb=1.209842 time=84.7s + ttt_chunk [761/1238] bpb=1.208942 time=85.8s + ttt_chunk [771/1238] bpb=1.207804 time=86.9s + ttt_chunk [781/1238] bpb=1.207670 time=88.1s + ttt_chunk [791/1238] bpb=1.208098 time=89.2s + ttt_chunk [801/1238] bpb=1.208571 time=90.3s + ttt_chunk [811/1238] bpb=1.207956 time=91.4s + ttt_chunk [821/1238] bpb=1.206956 time=92.6s + ttt_chunk [831/1238] bpb=1.206656 time=93.8s + ttt_chunk [841/1238] bpb=1.206287 time=94.9s + ttt_chunk [851/1238] bpb=1.205831 time=96.0s + ttt_chunk [861/1238] bpb=1.205390 time=97.2s + ttt_chunk [871/1238] bpb=1.205079 time=98.3s + ttt_chunk [881/1238] bpb=1.204664 time=99.4s + ttt_chunk [891/1238] bpb=1.204037 time=100.5s + ttt_chunk [901/1238] bpb=1.204502 time=101.6s + ttt_chunk [911/1238] bpb=1.204336 time=102.8s + ttt_chunk [921/1238] bpb=1.204698 time=103.8s + ttt_chunk [931/1238] bpb=1.205485 time=104.9s + ttt_chunk [941/1238] bpb=1.205953 time=106.1s + ttt_chunk [951/1238] bpb=1.205961 time=107.2s + ttt_chunk [961/1238] bpb=1.206816 time=108.3s + ttt_chunk [971/1238] bpb=1.207364 time=109.5s + ttt_chunk [981/1238] bpb=1.207769 time=110.6s + ttt_chunk [991/1238] bpb=1.207564 time=111.7s + ttt_chunk [1001/1238] bpb=1.207823 time=112.9s + ttt_chunk [1011/1238] bpb=1.208201 time=114.0s + ttt_chunk [1021/1238] bpb=1.208882 time=115.0s + ttt_chunk [1031/1238] bpb=1.209421 time=116.2s + ttt_chunk [1041/1238] bpb=1.209821 time=117.3s + ttt_chunk [1051/1238] bpb=1.209626 time=118.4s + ttt_chunk [1061/1238] bpb=1.209692 time=119.5s + ttt_chunk [1071/1238] bpb=1.209692 time=120.6s + ttt_chunk [1081/1238] bpb=1.209514 time=121.7s + ttt_chunk [1091/1238] bpb=1.209565 time=122.8s + ttt_chunk [1101/1238] bpb=1.210061 time=123.9s + ttt_chunk [1111/1238] bpb=1.210336 time=125.0s + ttt_chunk [1121/1238] bpb=1.210449 time=126.1s + ttt_chunk [1131/1238] bpb=1.209997 time=127.2s + ttt_chunk [1141/1238] bpb=1.209613 time=128.3s + ttt_chunk [1151/1238] bpb=1.209641 time=129.4s + ttt_chunk [1161/1238] bpb=1.209768 time=130.5s + ttt_chunk [1171/1238] bpb=1.209458 time=131.7s + ttt_chunk [1181/1238] bpb=1.208948 time=132.8s + ttt_chunk [1191/1238] bpb=1.209072 time=133.9s + ttt_chunk [1201/1238] bpb=1.209199 time=135.1s + ttt_chunk [1211/1238] bpb=1.208972 time=136.2s + ttt_chunk [1221/1238] bpb=1.208400 time=137.3s + ttt_chunk [1231/1238] bpb=1.208093 time=138.4s + ttt_chunk [1238/1238] bpb=1.208033 time=139.1s +ttt:done val_loss=3.118989 val_bpb=1.207458 elapsed=139.1s +final_int8_ttt val_loss:3.1190 val_bpb:1.2075 eval_time:139102ms +final_int8_ttt_exact val_loss:3.11898899 val_bpb:1.20745783 diff --git a/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed42.log b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed42.log new file mode 100644 index 0000000000..03806c552a --- /dev/null +++ b/records/track_10min_16mb/2026-04-24_SP8192_HeadwiseGate_LeakyReLU2_LegalTTT/train_seed42.log @@ -0,0 +1,1929 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + # Gated Attention (https://arxiv.org/abs/2505.06708, NeurIPS 2025 Best Paper) + # "none" = disabled (baseline), "headwise" = 1 gate per head, "elementwise" = 1 gate per dim + gated_attn = os.environ.get("GATED_ATTN", "none") + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Activation: "relu2" (default), "leaky_relu2" = LeakyReLU(0.5)² (PG ranks 10-11) + activation = os.environ.get("ACTIVATION", "relu2") + + # Test-Time Training (Score-First TTT, PG ranks 1-3) + # "none" = disabled (baseline), "score_first" = legal score-before-update TTT + ttt_mode = os.environ.get("TTT_MODE", "none") + ttt_lr = float(os.environ.get("TTT_LR", "0.005")) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", "0.9")) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "3")) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", "32768")) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", "0")) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", "1.0")) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", "32")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# SCORE-FIRST TTT (Legal Test-Time Training) +# ----------------------------- +# Protocol (PR #461 recipe, PG ranks 1-3): +# 1. Chunk validation data into ttt_chunk_tokens windows +# 2. For each chunk: SCORE all tokens under inference_mode(), then TRAIN on scored data +# 3. Last chunk: score only (never train on un-scored data) +# Guarantees: every token scored BEFORE any parameter update that could use it. + +def eval_val_score_first_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + stride = seq_len # non-overlapping windows for simplicity + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + + # Assign windows to chunks + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + ci = min(ws // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"windows={len(window_starts)} ttt_lr={args.ttt_lr} " + f"ttt_epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = any(f"blocks.{bi}." in name for bi in frozen_block_ids) + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- PHASE 1: SCORE (no gradients) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), args.ttt_batch_seqs): + batch_ws = my_windows[bi:bi + args.ttt_batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i in range(bsz): + wlen = wlens[i] + scored_nll = nll[i, :wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen) + tgt = y_batch[i, :wlen] + prev = x_batch[i, :wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- PHASE 2: TRAIN on scored chunk (skip last chunk) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine LR decay across chunks + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + + for _ep in range(args.ttt_epochs): + for bs in range(my_seq_s, my_seq_e, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_seq_e) + start_tok = chunk_start + bs * seq_len + end_tok = chunk_start + be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + # AllReduce across ranks + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore grad for all params + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +# ╔═══════════════════════════════════════════════════════════════════════════════════╗ +# ║ CAUSAL SELF-ATTENTION (with GQA) ║ +# ║ ║ +# ║ Plain English: Each token asks "which earlier tokens should I pay attention to?" ║ +# ║ It computes a compatibility score (Q·K) with every previous token, then takes ║ +# ║ a weighted average of their values (V). "Causal" = can only look backward. ║ +# ║ ║ +# ║ ┌──────────────── Input x: [batch, seq_len, 512] ────────────────┐ ║ +# ║ │ │ │ ║ +# ║ │ ┌──────────────────┼──────────────────┐ │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ ▼ ▼ ▼ │ ║ +# ║ │ c_q (512→512) c_k (512→256) c_v (512→256) │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ ▼ ▼ ▼ │ ║ +# ║ │ Q: 8 heads × 64d K: 4 heads × 64d V: 4 heads × 64d │ ║ +# ║ │ │ │ │ ║ +# ║ │ ▼ ▼ │ ║ +# ║ │ RMSNorm(Q) RMSNorm(K) ← stabilize magnitudes │ ║ +# ║ │ │ │ │ ║ +# ║ │ ▼ ▼ │ ║ +# ║ │ RoPE(Q) RoPE(K) ← encode positions │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ Q × q_gain ← learnable sharpness │ ║ +# ║ │ │ │ │ │ ║ +# ║ │ └──────────────────┼──────────────────┘ │ ║ +# ║ │ ▼ │ ║ +# ║ │ scaled_dot_product_attention │ ║ +# ║ │ softmax(Q·Kᵀ / √64) · V │ ║ +# ║ │ (causal mask + FlashAttention + GQA) │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ proj (512→512) │ ║ +# ║ │ │ │ ║ +# ║ │ ▼ │ ║ +# ║ │ Output: [batch, seq_len, 512] │ ║ +# ║ └─────────────────────────────────────────────────────────────────┘ ║ +# ║ ║ +# ║ GROUPED QUERY ATTENTION (GQA) — why K,V have 4 heads but Q has 8: ║ +# ║ ║ +# ║ Standard (MHA, 8×8): This model (GQA, 8×4): ║ +# ║ Q₀ → K₀ Q₄ → K₄ Q₀ ─┐ ║ +# ║ Q₁ → K₁ Q₅ → K₅ Q₁ ─┤→ K₀,V₀ ║ +# ║ Q₂ → K₂ Q₆ → K₆ Q₂ ─┐ ║ +# ║ Q₃ → K₃ Q₇ → K₇ Q₃ ─┤→ K₁,V₁ ║ +# ║ (8 KV pairs = full cost) Q₄ ─┐ ║ +# ║ Q₅ ─┤→ K₂,V₂ ║ +# ║ Q₆ ─┐ ║ +# ║ Q₇ ─┤→ K₃,V₃ ║ +# ║ (4 KV pairs = 50% memory) ║ +# ║ ║ +# ║ CAUSAL MASK (is_causal=True) — what makes this a language model: ║ +# ║ ║ +# ║ attends to→ t₀ t₁ t₂ t₃ t₄ ║ +# ║ t₀ ✓ ✗ ✗ ✗ ✗ "The" can only see itself ║ +# ║ t₁ ✓ ✓ ✗ ✗ ✗ "cat" sees "The cat" ║ +# ║ t₂ ✓ ✓ ✓ ✗ ✗ "sat" sees "The cat sat" ║ +# ║ t₃ ✓ ✓ ✓ ✓ ✗ "on" sees "The cat sat on" ║ +# ║ t₄ ✓ ✓ ✓ ✓ ✓ "the" sees everything before it ║ +# ║ ║ +# ║ Each token predicts the NEXT token without peeking at the answer. ║ +# ║ ║ +# ║ WHY EACH STEP EXISTS: ║ +# ║ • RMSNorm on Q,K → prevents attention scores from exploding to ±inf ║ +# ║ • RoPE on Q,K → encodes WHERE each token is (position 0, 1, 2...) ║ +# ║ • q_gain (4.0-5.25) → learnable per-head sharpness (higher = more focused) ║ +# ║ • √64 denominator → scales dot products so softmax doesn't saturate ║ +# ║ • proj → mixes info from all 8 heads back to 512 dimensions ║ +# ╚═══════════════════════════════════════════════════════════════════════════════════╝ +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads # 512 / 8 = 64 dimensions per head + self.gated_attn = gated_attn + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim # 4 × 64 = 256 (half of Q's 512, GQA savings) + # Gated Attention (arxiv.org/abs/2505.06708, NeurIPS 2025 Best Paper): + # Widen Q projection to also produce gate logits from the same input. + # "none" → c_q: 512→512, no gate + # "headwise" → c_q: 512→520, +8 dims (1 gate scalar per head, ~9K extra params/layer) + # "elementwise" → c_q: 512→1024, +512 dims (1 gate per dimension, doubles Q projection) + if gated_attn == "headwise": + self.gate_dim = num_heads # 8 extra outputs + elif gated_attn == "elementwise": + self.gate_dim = dim # 512 extra outputs + else: + self.gate_dim = 0 + self.c_q = CastedLinear(dim, dim + self.gate_dim, bias=False) # Query + gate logits + self.c_k = CastedLinear(dim, kv_dim, bias=False) # Key: 512→256 (only 4 KV heads) + self.c_v = CastedLinear(dim, kv_dim, bias=False) # Value: 512→256 (only 4 KV heads) + self.proj = CastedLinear(dim, dim, bias=False) # Output projection: 512→512 + self.proj._zero_init = True # Start at zero so skip connections dominate early + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + # Step 1: Project input → Q (+ gate logits if gated attention is enabled) + q_out = self.c_q(x) + if self.gate_dim > 0: + q_raw, gate_logits = q_out.split([dim, self.gate_dim], dim=-1) + else: + q_raw = q_out + # Reshape into multi-head format: [batch, heads, seq, head_dim] + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Step 2: Normalize Q and K so dot products don't explode + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + # Step 3: Apply rotary position embeddings — encode WHERE each token sits + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + # Step 4: Scale Q by learnable per-head gain — controls attention sharpness + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Step 5: The actual attention — softmax(Q·Kᵀ/√d) · V with causal mask + # Uses FlashAttention kernel under the hood (memory-efficient, no full NxN matrix) + # GQA: repeat K,V heads to match Q heads (e.g. 4 KV → 8 Q, repeat 2×) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + ) + # Step 6 (Gated Attention): sigmoid gate AFTER attention, BEFORE output projection + # Gate logits come from the Q projection (query-dependent, input-dependent). + # sigmoid(gate) ∈ [0,1] lets the model suppress uninformative heads/dims per token. + # Headwise: [bsz, seq, 8] → [bsz, 8, seq, 1] (one scalar per head) + # Elementwise: [bsz, seq, 512] → [bsz, 8, seq, 64] (one per dimension) + if self.gated_attn == "headwise": + gate = torch.sigmoid(gate_logits) # [bsz, seqlen, num_heads] + gate = gate.transpose(1, 2).unsqueeze(-1) # [bsz, num_heads, seqlen, 1] + y = y * gate + elif self.gated_attn == "elementwise": + gate = torch.sigmoid(gate_logits) # [bsz, seqlen, dim] + gate = gate.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + y = y * gate + # Step 7: Reshape from multi-head back to [batch, seq, 512] and project + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, activation: str = "relu2"): + super().__init__() + self.activation = activation + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = self.fc(x) + if self.activation == "leaky_relu2": + x = F.leaky_relu(x, negative_slope=0.5) + else: + x = torch.relu(x) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + activation: str = "relu2", + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, gated_attn) + self.mlp = MLP(dim, mlp_mult, activation) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + gated_attn: str = "none", + activation: str = "relu2", + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + gated_attn, + activation, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _run_backbone(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits [batch, seq, vocab] without computing loss.""" + return self._run_backbone(input_ids) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self._run_backbone(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + gated_attn=args.gated_attn, + activation=args.activation, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # ===================================================================== + # OPTIMIZER SPLIT — Different param types get different optimizers & LRs + # ===================================================================== + # + # The model's parameters are sorted into 3-4 groups, each with its own + # optimizer and learning rate. Think of it like a team where each role + # needs a different management style: + # + # ┌─────────────────────────────────────────────────────────────────┐ + # │ Parameter Group Optimizer LR % of Params │ + # │ ───────────────────── ───────── ────── ─────────── │ + # │ 1. Embedding table Adam 0.05 ~3% (SP1024) │ + # │ (tok_emb.weight) ~12% (SP8192) │ + # │ "the dictionary" │ + # │ │ + # │ 2. Matrix weights MUON 0.04 ~95% of blocks │ + # │ (Q, K, V, proj, The heavy lifters │ + # │ fc, mlp.proj) — 2D matrices │ + # │ "attention + MLP that do the real │ + # │ computation" learning │ + # │ │ + # │ 3. Scalar params Adam 0.04 <1% │ + # │ (attn_scale, Tiny knobs that │ + # │ mlp_scale, fine-tune how │ + # │ resid_mix, blocks blend │ + # │ q_gain, their outputs │ + # │ skip_weights) │ + # │ │ + # │ 4. Output head Adam 0.008 Only if untied │ + # │ (lm_head.weight) embeddings │ + # │ "vector → token (not used in │ + # │ probability" baseline) │ + # └─────────────────────────────────────────────────────────────────┘ + # + # WHY MUON FOR MATRICES? + # Muon orthogonalizes gradient updates via Newton-Schulz iterations. + # This only works on 2D matrices (needs rows & columns to orthogonalize). + # 1D vectors and scalars are too small — they use standard Adam instead. + # + # WHY DIFFERENT LEARNING RATES? + # - Embeddings are a lookup table — aggressive updates cause instability + # - Matrix weights are the core compute — Muon handles the LR scaling + # - Scalars are sensitive knobs — moderate LR with Adam's adaptive step + # - Output head (if untied) maps back to vocab — needs gentler updates + # + block_named_params = list(base_model.blocks.named_parameters()) + + # --- Group 2: Matrix params (2D tensors in transformer blocks) → Muon --- + # Collects: c_q.weight, c_k.weight, c_v.weight, proj.weight, fc.weight, mlp.proj.weight + # Excludes: control tensors like attn_scale, resid_mix (those are 2D but act as scalars) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + + # --- Group 3: Scalar params (1D vectors + control tensors) → Adam --- + # Collects: attn_scale, mlp_scale, resid_mix, q_gain, skip_weights + # These are the small "knobs" that tune how blocks combine their outputs + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) # U-Net skip connection weights + + # --- Group 1: Embedding table → Adam (gentle LR) --- + # The "dictionary" that maps token IDs to 512-dim vectors + # If tied: same weight used for input embedding AND output prediction (LR=0.05) + # If untied: separate input embedding (LR=0.6) and output head (LR=0.008) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, # fused=True: single GPU kernel for speed + ) + + # --- Group 2 optimizer: Muon for the heavy 2D matrices --- + # Uses Newton-Schulz orthogonalization (5 steps) to normalize gradients + # before applying them. This is the key innovation from modded-nanogpt. + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr # Stash base LR for warmdown scheduling + + # --- Group 3 optimizer: Adam for small scalar/vector params --- + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + + # Collect all optimizers — training loop will step() all of them each iteration + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + + # --- Group 4 (optional): Output head → Adam (only when embeddings are untied) --- + # In baseline, tie_embeddings=True so lm_head is None and this is skipped. + # When untied, the output head gets its own gentle LR (0.008) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Score-First TTT on quantized model (if enabled) + if args.ttt_mode == "score_first": + torch.cuda.synchronize() + t_ttt = time.perf_counter() + # Save quantized weights, TTT will modify them during adaptation + saved_state = copy.deepcopy(base_model.state_dict()) + ttt_loss, ttt_bpb = eval_val_score_first_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + log0=log0, + ) + # Restore quantized weights (TTT modifies them in-place) + base_model.load_state_dict(saved_state) + torch.cuda.synchronize() + log0(f"final_int8_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"final_int8_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Fri Apr 24 00:15:35 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 35C P0 115W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 35C P0 121W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 36C P0 117W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 36C P0 117W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_8192_bpe.model +train_loader:dataset:fineweb10B_sp8192 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin tokens:40540160 +model_params:16364616 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:9.0069 val_bpb:3.4868 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:9.0070 train_time:37ms step_avg:36.94ms +step:2/20000 train_loss:15.3495 train_time:97ms step_avg:48.66ms +step:3/20000 train_loss:11.1037 train_time:148ms step_avg:49.23ms +step:4/20000 train_loss:8.9281 train_time:202ms step_avg:50.38ms +step:5/20000 train_loss:8.5952 train_time:256ms step_avg:51.13ms +step:6/20000 train_loss:8.5455 train_time:310ms step_avg:51.66ms +step:7/20000 train_loss:8.2801 train_time:364ms step_avg:51.98ms +step:8/20000 train_loss:8.1036 train_time:419ms step_avg:52.33ms +step:9/20000 train_loss:7.6697 train_time:473ms step_avg:52.51ms +step:10/20000 train_loss:7.4806 train_time:527ms step_avg:52.67ms +step:100/20000 train_loss:4.5805 train_time:5410ms step_avg:54.10ms +step:200/20000 train_loss:3.9409 train_time:10892ms step_avg:54.46ms +step:200/20000 val_loss:3.9545 val_bpb:1.5309 train_time:10911ms step_avg:54.56ms +step:300/20000 train_loss:3.8303 train_time:16290ms step_avg:54.30ms +step:400/20000 train_loss:3.6828 train_time:21745ms step_avg:54.36ms +step:400/20000 val_loss:3.6772 val_bpb:1.4236 train_time:21766ms step_avg:54.41ms +step:500/20000 train_loss:3.6218 train_time:27142ms step_avg:54.28ms +step:600/20000 train_loss:3.5698 train_time:32595ms step_avg:54.33ms +step:600/20000 val_loss:3.5624 val_bpb:1.3791 train_time:32616ms step_avg:54.36ms +step:700/20000 train_loss:3.3834 train_time:37995ms step_avg:54.28ms +step:800/20000 train_loss:3.4537 train_time:43450ms step_avg:54.31ms +step:800/20000 val_loss:3.4811 val_bpb:1.3477 train_time:43471ms step_avg:54.34ms +step:900/20000 train_loss:3.5998 train_time:48855ms step_avg:54.28ms +step:1000/20000 train_loss:3.1952 train_time:54329ms step_avg:54.33ms +step:1000/20000 val_loss:3.4326 val_bpb:1.3289 train_time:54348ms step_avg:54.35ms +step:1100/20000 train_loss:3.4299 train_time:59742ms step_avg:54.31ms +step:1200/20000 train_loss:3.4245 train_time:65215ms step_avg:54.35ms +step:1200/20000 val_loss:3.3965 val_bpb:1.3149 train_time:65234ms step_avg:54.36ms +step:1300/20000 train_loss:3.4110 train_time:70626ms step_avg:54.33ms +step:1400/20000 train_loss:3.4871 train_time:76094ms step_avg:54.35ms +step:1400/20000 val_loss:3.3737 val_bpb:1.3061 train_time:76113ms step_avg:54.37ms +step:1500/20000 train_loss:3.5443 train_time:81502ms step_avg:54.33ms +step:1600/20000 train_loss:3.5690 train_time:86973ms step_avg:54.36ms +step:1600/20000 val_loss:3.3531 val_bpb:1.2981 train_time:86992ms step_avg:54.37ms +step:1700/20000 train_loss:3.4679 train_time:92380ms step_avg:54.34ms +step:1800/20000 train_loss:3.2814 train_time:97848ms step_avg:54.36ms +step:1800/20000 val_loss:3.3365 val_bpb:1.2917 train_time:97869ms step_avg:54.37ms +step:1900/20000 train_loss:3.4602 train_time:103253ms step_avg:54.34ms +step:2000/20000 train_loss:3.3016 train_time:108725ms step_avg:54.36ms +step:2000/20000 val_loss:3.3232 val_bpb:1.2865 train_time:108745ms step_avg:54.37ms +step:2100/20000 train_loss:3.1967 train_time:114192ms step_avg:54.38ms +step:2200/20000 train_loss:3.3338 train_time:119602ms step_avg:54.36ms +step:2200/20000 val_loss:3.3112 val_bpb:1.2819 train_time:119623ms step_avg:54.37ms +step:2300/20000 train_loss:3.4391 train_time:125074ms step_avg:54.38ms +step:2400/20000 train_loss:3.3066 train_time:130481ms step_avg:54.37ms +step:2400/20000 val_loss:3.2973 val_bpb:1.2765 train_time:130502ms step_avg:54.38ms +step:2500/20000 train_loss:3.1785 train_time:135950ms step_avg:54.38ms +step:2600/20000 train_loss:3.2917 train_time:141357ms step_avg:54.37ms +step:2600/20000 val_loss:3.2864 val_bpb:1.2723 train_time:141379ms step_avg:54.38ms +step:2700/20000 train_loss:3.4169 train_time:146832ms step_avg:54.38ms +step:2800/20000 train_loss:3.4657 train_time:152238ms step_avg:54.37ms +step:2800/20000 val_loss:3.2818 val_bpb:1.2705 train_time:152259ms step_avg:54.38ms +step:2900/20000 train_loss:3.2363 train_time:157706ms step_avg:54.38ms +step:3000/20000 train_loss:3.3869 train_time:163114ms step_avg:54.37ms +step:3000/20000 val_loss:3.2730 val_bpb:1.2671 train_time:163135ms step_avg:54.38ms +step:3100/20000 train_loss:3.1769 train_time:168589ms step_avg:54.38ms +step:3200/20000 train_loss:3.4884 train_time:173996ms step_avg:54.37ms +step:3200/20000 val_loss:3.2662 val_bpb:1.2645 train_time:174016ms step_avg:54.38ms +step:3300/20000 train_loss:3.3048 train_time:179464ms step_avg:54.38ms +step:3400/20000 train_loss:3.3978 train_time:184873ms step_avg:54.37ms +step:3400/20000 val_loss:3.2617 val_bpb:1.2627 train_time:184894ms step_avg:54.38ms +step:3500/20000 train_loss:3.5926 train_time:190350ms step_avg:54.39ms +step:3600/20000 train_loss:3.2022 train_time:195757ms step_avg:54.38ms +step:3600/20000 val_loss:3.2539 val_bpb:1.2597 train_time:195778ms step_avg:54.38ms +step:3700/20000 train_loss:3.2262 train_time:201227ms step_avg:54.39ms +step:3800/20000 train_loss:3.1936 train_time:206637ms step_avg:54.38ms +step:3800/20000 val_loss:3.2506 val_bpb:1.2584 train_time:206658ms step_avg:54.38ms +step:3900/20000 train_loss:3.2048 train_time:212114ms step_avg:54.39ms +step:4000/20000 train_loss:3.3121 train_time:217621ms step_avg:54.41ms +step:4000/20000 val_loss:3.2438 val_bpb:1.2558 train_time:217650ms step_avg:54.41ms +step:4100/20000 train_loss:2.6945 train_time:223096ms step_avg:54.41ms +step:4200/20000 train_loss:3.1780 train_time:228573ms step_avg:54.42ms +step:4200/20000 val_loss:3.2366 val_bpb:1.2530 train_time:228594ms step_avg:54.43ms +step:4300/20000 train_loss:3.2658 train_time:233977ms step_avg:54.41ms +step:4400/20000 train_loss:3.1496 train_time:239440ms step_avg:54.42ms +step:4400/20000 val_loss:3.2357 val_bpb:1.2527 train_time:239461ms step_avg:54.42ms +step:4500/20000 train_loss:3.1992 train_time:244847ms step_avg:54.41ms +step:4600/20000 train_loss:3.1715 train_time:250319ms step_avg:54.42ms +step:4600/20000 val_loss:3.2302 val_bpb:1.2505 train_time:250340ms step_avg:54.42ms +step:4700/20000 train_loss:3.1489 train_time:255722ms step_avg:54.41ms +step:4800/20000 train_loss:3.2850 train_time:261184ms step_avg:54.41ms +step:4800/20000 val_loss:3.2256 val_bpb:1.2487 train_time:261205ms step_avg:54.42ms +step:4900/20000 train_loss:3.1256 train_time:266590ms step_avg:54.41ms +step:5000/20000 train_loss:3.3067 train_time:272062ms step_avg:54.41ms +step:5000/20000 val_loss:3.2226 val_bpb:1.2476 train_time:272083ms step_avg:54.42ms +step:5100/20000 train_loss:3.3269 train_time:277465ms step_avg:54.40ms +step:5200/20000 train_loss:3.2949 train_time:282927ms step_avg:54.41ms +step:5200/20000 val_loss:3.2209 val_bpb:1.2469 train_time:282948ms step_avg:54.41ms +step:5300/20000 train_loss:3.2354 train_time:288332ms step_avg:54.40ms +step:5400/20000 train_loss:3.1966 train_time:293805ms step_avg:54.41ms +step:5400/20000 val_loss:3.2164 val_bpb:1.2452 train_time:293825ms step_avg:54.41ms +step:5500/20000 train_loss:3.2028 train_time:299207ms step_avg:54.40ms +step:5600/20000 train_loss:3.3100 train_time:304669ms step_avg:54.41ms +step:5600/20000 val_loss:3.2140 val_bpb:1.2442 train_time:304690ms step_avg:54.41ms +step:5700/20000 train_loss:3.0373 train_time:310073ms step_avg:54.40ms +step:5800/20000 train_loss:3.2252 train_time:315546ms step_avg:54.40ms +step:5800/20000 val_loss:3.2115 val_bpb:1.2433 train_time:315567ms step_avg:54.41ms +step:5900/20000 train_loss:3.2117 train_time:320949ms step_avg:54.40ms +step:6000/20000 train_loss:3.3454 train_time:326412ms step_avg:54.40ms +step:6000/20000 val_loss:3.2090 val_bpb:1.2423 train_time:326433ms step_avg:54.41ms +step:6100/20000 train_loss:3.2287 train_time:331816ms step_avg:54.40ms +step:6200/20000 train_loss:3.3731 train_time:337289ms step_avg:54.40ms +step:6200/20000 val_loss:3.2089 val_bpb:1.2423 train_time:337310ms step_avg:54.40ms +step:6300/20000 train_loss:3.2848 train_time:342750ms step_avg:54.40ms +step:6400/20000 train_loss:3.2794 train_time:348153ms step_avg:54.40ms +step:6400/20000 val_loss:3.2089 val_bpb:1.2423 train_time:348174ms step_avg:54.40ms +step:6500/20000 train_loss:3.2250 train_time:353630ms step_avg:54.40ms +step:6600/20000 train_loss:3.2528 train_time:359033ms step_avg:54.40ms +step:6600/20000 val_loss:3.2024 val_bpb:1.2398 train_time:359055ms step_avg:54.40ms +step:6700/20000 train_loss:2.9342 train_time:364496ms step_avg:54.40ms +step:6800/20000 train_loss:3.1303 train_time:369898ms step_avg:54.40ms +step:6800/20000 val_loss:3.2006 val_bpb:1.2391 train_time:369919ms step_avg:54.40ms +step:6900/20000 train_loss:3.2926 train_time:375631ms step_avg:54.44ms +step:7000/20000 train_loss:3.1737 train_time:381036ms step_avg:54.43ms +step:7000/20000 val_loss:3.1981 val_bpb:1.2381 train_time:381057ms step_avg:54.44ms +step:7100/20000 train_loss:3.2016 train_time:386499ms step_avg:54.44ms +step:7200/20000 train_loss:3.0915 train_time:391901ms step_avg:54.43ms +step:7200/20000 val_loss:3.1937 val_bpb:1.2364 train_time:391922ms step_avg:54.43ms +step:7300/20000 train_loss:2.9787 train_time:397372ms step_avg:54.43ms +step:7400/20000 train_loss:3.1655 train_time:402778ms step_avg:54.43ms +step:7400/20000 val_loss:3.1959 val_bpb:1.2372 train_time:402800ms step_avg:54.43ms +step:7500/20000 train_loss:3.2381 train_time:408250ms step_avg:54.43ms +step:7600/20000 train_loss:3.3154 train_time:413654ms step_avg:54.43ms +step:7600/20000 val_loss:3.1901 val_bpb:1.2350 train_time:413674ms step_avg:54.43ms +step:7700/20000 train_loss:3.1743 train_time:419125ms step_avg:54.43ms +step:7800/20000 train_loss:3.3929 train_time:424533ms step_avg:54.43ms +step:7800/20000 val_loss:3.1908 val_bpb:1.2353 train_time:424551ms step_avg:54.43ms +step:7900/20000 train_loss:3.2292 train_time:429993ms step_avg:54.43ms +step:8000/20000 train_loss:3.1787 train_time:435399ms step_avg:54.42ms +step:8000/20000 val_loss:3.1883 val_bpb:1.2343 train_time:435420ms step_avg:54.43ms +step:8100/20000 train_loss:3.2338 train_time:440873ms step_avg:54.43ms +step:8200/20000 train_loss:3.1517 train_time:446278ms step_avg:54.42ms +step:8200/20000 val_loss:3.1880 val_bpb:1.2342 train_time:446300ms step_avg:54.43ms +step:8300/20000 train_loss:3.2154 train_time:451740ms step_avg:54.43ms +step:8400/20000 train_loss:3.1496 train_time:457210ms step_avg:54.43ms +step:8400/20000 val_loss:3.1917 val_bpb:1.2356 train_time:457231ms step_avg:54.43ms +step:8500/20000 train_loss:3.2088 train_time:462614ms step_avg:54.43ms +step:8600/20000 train_loss:3.1975 train_time:468077ms step_avg:54.43ms +step:8600/20000 val_loss:3.1817 val_bpb:1.2317 train_time:468098ms step_avg:54.43ms +step:8700/20000 train_loss:3.2228 train_time:473478ms step_avg:54.42ms +step:8800/20000 train_loss:3.3391 train_time:478947ms step_avg:54.43ms +step:8800/20000 val_loss:3.1822 val_bpb:1.2319 train_time:478968ms step_avg:54.43ms +step:8900/20000 train_loss:3.2818 train_time:484354ms step_avg:54.42ms +step:9000/20000 train_loss:3.1970 train_time:489818ms step_avg:54.42ms +step:9000/20000 val_loss:3.1797 val_bpb:1.2310 train_time:489839ms step_avg:54.43ms +step:9100/20000 train_loss:3.1193 train_time:495222ms step_avg:54.42ms +step:9200/20000 train_loss:3.1824 train_time:500696ms step_avg:54.42ms +step:9200/20000 val_loss:3.1848 val_bpb:1.2330 train_time:500718ms step_avg:54.43ms +step:9300/20000 train_loss:3.2300 train_time:506101ms step_avg:54.42ms +step:9400/20000 train_loss:3.1601 train_time:511565ms step_avg:54.42ms +step:9400/20000 val_loss:3.1753 val_bpb:1.2293 train_time:511586ms step_avg:54.42ms +step:9500/20000 train_loss:3.0999 train_time:516969ms step_avg:54.42ms +step:9600/20000 train_loss:3.1371 train_time:522441ms step_avg:54.42ms +step:9600/20000 val_loss:3.1739 val_bpb:1.2287 train_time:522462ms step_avg:54.42ms +step:9700/20000 train_loss:3.3307 train_time:527845ms step_avg:54.42ms +step:9800/20000 train_loss:3.1221 train_time:533310ms step_avg:54.42ms +step:9800/20000 val_loss:3.1754 val_bpb:1.2293 train_time:533331ms step_avg:54.42ms +step:9900/20000 train_loss:3.1758 train_time:538712ms step_avg:54.42ms +step:10000/20000 train_loss:3.1863 train_time:544182ms step_avg:54.42ms +step:10000/20000 val_loss:3.1651 val_bpb:1.2253 train_time:544203ms step_avg:54.42ms +step:10100/20000 train_loss:3.1795 train_time:549586ms step_avg:54.41ms +step:10200/20000 train_loss:3.2239 train_time:555048ms step_avg:54.42ms +step:10200/20000 val_loss:3.1537 val_bpb:1.2209 train_time:555069ms step_avg:54.42ms +step:10300/20000 train_loss:3.3824 train_time:560517ms step_avg:54.42ms +step:10400/20000 train_loss:3.1915 train_time:565921ms step_avg:54.42ms +step:10400/20000 val_loss:3.1418 val_bpb:1.2163 train_time:565943ms step_avg:54.42ms +step:10500/20000 train_loss:3.3167 train_time:571382ms step_avg:54.42ms +step:10600/20000 train_loss:3.1134 train_time:576782ms step_avg:54.41ms +step:10600/20000 val_loss:3.1286 val_bpb:1.2112 train_time:576803ms step_avg:54.42ms +step:10700/20000 train_loss:3.1631 train_time:582249ms step_avg:54.42ms +step:10800/20000 train_loss:3.2310 train_time:587646ms step_avg:54.41ms +step:10800/20000 val_loss:3.1154 val_bpb:1.2061 train_time:587667ms step_avg:54.41ms +step:10900/20000 train_loss:3.1961 train_time:593102ms step_avg:54.41ms +step:11000/20000 train_loss:3.1909 train_time:598499ms step_avg:54.41ms +step:11000/20000 val_loss:3.1063 val_bpb:1.2025 train_time:598520ms step_avg:54.41ms +step:11028/20000 val_loss:3.1059 val_bpb:1.2024 train_time:600030ms step_avg:54.41ms +stopping_early: wallclock_cap train_time:600030ms step:11028/20000 +peak memory allocated: 10490 MiB reserved: 11544 MiB +Serialized model: 58152343 bytes +Code size: 73221 bytes +Total submission size: 58225564 bytes +Serialized model int8+zlib: 15340685 bytes (payload:16483504 raw_torch:16528409 payload_ratio:3.53x) +Total submission size int8+zlib: 15413906 bytes +final_int8_zlib_roundtrip val_loss:3.1251 val_bpb:1.2098 eval_time:1089ms +final_int8_zlib_roundtrip_exact val_loss:3.12506621 val_bpb:1.20981052 +ttt:start chunks=1238 chunk_tokens=32768 windows=39590 ttt_lr=0.005 ttt_epochs=3 freeze_blocks=0 +ttt:params unfrozen=16364616 frozen=0 + ttt_chunk [1/1238] bpb=1.213960 time=0.3s + ttt_chunk [11/1238] bpb=1.185170 time=1.4s + ttt_chunk [21/1238] bpb=1.233462 time=2.5s + ttt_chunk [31/1238] bpb=1.232353 time=3.6s + ttt_chunk [41/1238] bpb=1.226359 time=4.7s + ttt_chunk [51/1238] bpb=1.220914 time=5.8s + ttt_chunk [61/1238] bpb=1.213632 time=6.9s + ttt_chunk [71/1238] bpb=1.219326 time=8.0s + ttt_chunk [81/1238] bpb=1.214166 time=9.1s + ttt_chunk [91/1238] bpb=1.210535 time=10.1s + ttt_chunk [101/1238] bpb=1.209532 time=11.2s + ttt_chunk [111/1238] bpb=1.208397 time=12.3s + ttt_chunk [121/1238] bpb=1.211534 time=13.4s + ttt_chunk [131/1238] bpb=1.215234 time=14.5s + ttt_chunk [141/1238] bpb=1.215423 time=15.6s + ttt_chunk [151/1238] bpb=1.214929 time=16.7s + ttt_chunk [161/1238] bpb=1.215454 time=17.8s + ttt_chunk [171/1238] bpb=1.215333 time=18.8s + ttt_chunk [181/1238] bpb=1.213529 time=19.9s + ttt_chunk [191/1238] bpb=1.213063 time=21.0s + ttt_chunk [201/1238] bpb=1.210286 time=22.1s + ttt_chunk [211/1238] bpb=1.214710 time=23.2s + ttt_chunk [221/1238] bpb=1.214507 time=24.3s + ttt_chunk [231/1238] bpb=1.216251 time=25.4s + ttt_chunk [241/1238] bpb=1.215930 time=26.5s + ttt_chunk [251/1238] bpb=1.216113 time=27.6s + ttt_chunk [261/1238] bpb=1.216350 time=28.7s + ttt_chunk [271/1238] bpb=1.216465 time=29.8s + ttt_chunk [281/1238] bpb=1.215550 time=30.9s + ttt_chunk [291/1238] bpb=1.216590 time=32.1s + ttt_chunk [301/1238] bpb=1.216402 time=33.1s + ttt_chunk [311/1238] bpb=1.214764 time=34.3s + ttt_chunk [321/1238] bpb=1.214624 time=35.4s + ttt_chunk [331/1238] bpb=1.214961 time=36.4s + ttt_chunk [341/1238] bpb=1.214126 time=37.6s + ttt_chunk [351/1238] bpb=1.214878 time=38.6s + ttt_chunk [361/1238] bpb=1.213800 time=39.7s + ttt_chunk [371/1238] bpb=1.212219 time=40.8s + ttt_chunk [381/1238] bpb=1.212428 time=42.0s + ttt_chunk [391/1238] bpb=1.211991 time=43.0s + ttt_chunk [401/1238] bpb=1.211638 time=44.1s + ttt_chunk [411/1238] bpb=1.212205 time=45.2s + ttt_chunk [421/1238] bpb=1.211422 time=46.3s + ttt_chunk [431/1238] bpb=1.211544 time=47.3s + ttt_chunk [441/1238] bpb=1.211657 time=48.5s + ttt_chunk [451/1238] bpb=1.212990 time=49.6s + ttt_chunk [461/1238] bpb=1.211200 time=50.7s + ttt_chunk [471/1238] bpb=1.211167 time=51.7s + ttt_chunk [481/1238] bpb=1.211235 time=52.8s + ttt_chunk [491/1238] bpb=1.211563 time=53.9s + ttt_chunk [501/1238] bpb=1.211019 time=55.0s + ttt_chunk [511/1238] bpb=1.210805 time=56.1s + ttt_chunk [521/1238] bpb=1.210427 time=57.3s + ttt_chunk [531/1238] bpb=1.210362 time=58.4s + ttt_chunk [541/1238] bpb=1.210471 time=59.5s + ttt_chunk [551/1238] bpb=1.210132 time=60.6s + ttt_chunk [561/1238] bpb=1.209716 time=61.7s + ttt_chunk [571/1238] bpb=1.209006 time=62.8s + ttt_chunk [581/1238] bpb=1.209196 time=63.9s + ttt_chunk [591/1238] bpb=1.209397 time=65.1s + ttt_chunk [601/1238] bpb=1.209350 time=66.2s + ttt_chunk [611/1238] bpb=1.209862 time=67.3s + ttt_chunk [621/1238] bpb=1.210664 time=68.4s + ttt_chunk [631/1238] bpb=1.210564 time=69.5s + ttt_chunk [641/1238] bpb=1.210867 time=70.6s + ttt_chunk [651/1238] bpb=1.211152 time=71.6s + ttt_chunk [661/1238] bpb=1.210387 time=72.8s + ttt_chunk [671/1238] bpb=1.209923 time=73.9s + ttt_chunk [681/1238] bpb=1.211243 time=74.9s + ttt_chunk [691/1238] bpb=1.211332 time=76.0s + ttt_chunk [701/1238] bpb=1.211041 time=77.1s + ttt_chunk [711/1238] bpb=1.211820 time=78.2s + ttt_chunk [721/1238] bpb=1.212159 time=79.3s + ttt_chunk [731/1238] bpb=1.211493 time=80.4s + ttt_chunk [741/1238] bpb=1.211358 time=81.5s + ttt_chunk [751/1238] bpb=1.210440 time=82.6s + ttt_chunk [761/1238] bpb=1.209539 time=83.7s + ttt_chunk [771/1238] bpb=1.208415 time=84.8s + ttt_chunk [781/1238] bpb=1.208277 time=85.9s + ttt_chunk [791/1238] bpb=1.208715 time=87.1s + ttt_chunk [801/1238] bpb=1.209145 time=88.2s + ttt_chunk [811/1238] bpb=1.208540 time=89.3s + ttt_chunk [821/1238] bpb=1.207517 time=90.4s + ttt_chunk [831/1238] bpb=1.207216 time=91.5s + ttt_chunk [841/1238] bpb=1.206873 time=92.6s + ttt_chunk [851/1238] bpb=1.206386 time=93.6s + ttt_chunk [861/1238] bpb=1.205954 time=94.7s + ttt_chunk [871/1238] bpb=1.205667 time=95.8s + ttt_chunk [881/1238] bpb=1.205243 time=96.9s + ttt_chunk [891/1238] bpb=1.204608 time=98.0s + ttt_chunk [901/1238] bpb=1.205071 time=99.1s + ttt_chunk [911/1238] bpb=1.204903 time=100.2s + ttt_chunk [921/1238] bpb=1.205243 time=101.3s + ttt_chunk [931/1238] bpb=1.206012 time=102.4s + ttt_chunk [941/1238] bpb=1.206505 time=103.5s + ttt_chunk [951/1238] bpb=1.206512 time=104.7s + ttt_chunk [961/1238] bpb=1.207350 time=105.8s + ttt_chunk [971/1238] bpb=1.207901 time=106.9s + ttt_chunk [981/1238] bpb=1.208328 time=108.0s + ttt_chunk [991/1238] bpb=1.208143 time=109.1s + ttt_chunk [1001/1238] bpb=1.208430 time=110.3s + ttt_chunk [1011/1238] bpb=1.208825 time=111.4s + ttt_chunk [1021/1238] bpb=1.209521 time=112.6s + ttt_chunk [1031/1238] bpb=1.210074 time=113.6s + ttt_chunk [1041/1238] bpb=1.210463 time=114.7s + ttt_chunk [1051/1238] bpb=1.210280 time=115.8s + ttt_chunk [1061/1238] bpb=1.210359 time=116.9s + ttt_chunk [1071/1238] bpb=1.210372 time=118.1s + ttt_chunk [1081/1238] bpb=1.210161 time=119.1s + ttt_chunk [1091/1238] bpb=1.210231 time=120.2s + ttt_chunk [1101/1238] bpb=1.210725 time=121.3s + ttt_chunk [1111/1238] bpb=1.210976 time=122.4s + ttt_chunk [1121/1238] bpb=1.211104 time=123.5s + ttt_chunk [1131/1238] bpb=1.210638 time=124.6s + ttt_chunk [1141/1238] bpb=1.210245 time=125.7s + ttt_chunk [1151/1238] bpb=1.210262 time=126.8s + ttt_chunk [1161/1238] bpb=1.210405 time=127.8s + ttt_chunk [1171/1238] bpb=1.210097 time=128.9s + ttt_chunk [1181/1238] bpb=1.209575 time=130.0s + ttt_chunk [1191/1238] bpb=1.209705 time=131.1s + ttt_chunk [1201/1238] bpb=1.209780 time=132.2s + ttt_chunk [1211/1238] bpb=1.209560 time=133.3s + ttt_chunk [1221/1238] bpb=1.208977 time=134.4s + ttt_chunk [1231/1238] bpb=1.208662 time=135.6s + ttt_chunk [1238/1238] bpb=1.208603 time=136.2s +ttt:done val_loss=3.119953 val_bpb=1.207831 elapsed=136.3s +final_int8_ttt val_loss:3.1200 val_bpb:1.2078 eval_time:136274ms +final_int8_ttt_exact val_loss:3.11995300 val_bpb:1.20783103