diff --git a/RUNPOD.md b/RUNPOD.md new file mode 100644 index 0000000000..b0b0790ef1 --- /dev/null +++ b/RUNPOD.md @@ -0,0 +1,99 @@ +# Parameter Golf — RunPod Workflow + +## Fastest model on the leaderboard (April 2026) +**Score: 1.0810 bpb** — bigbag, *SP8192 + 3-Layer Recurrence + Parallel Residuals + QK-Gain 5.25 + Legal Score-First TTT* +File: `records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/train_gpt.py` + +Key techniques: +- **SP8192** — 8192-token SentencePiece vocabulary +- **3-Layer Depth Recurrence** — layers 3-5 loop 3×, giving 17 effective layers from 11 physical +- **Parallel Residuals** (GPT-J style) — attention + MLP read same input from layer 7+ +- **QK-Gain 5.25** — learnable per-head query scaling +- **Score-First TTT** — SGD test-time training on eval chunks, scores first then updates +- **GPTQ SDClip int6/int8** + Brotli-11 compression to stay under 16 MB + +--- + +## 1. Launch a RunPod pod + +1. Go to [runpod.io](https://www.runpod.io) → **GPU Cloud** → **New Pod** +2. Use the official template: [Parameter Golf Template](https://www.runpod.io/console/gpu-cloud) + - Or choose any pod with a CUDA image (e.g., `runpod/pytorch:2.2.0-py3.10-cuda12.1-devel`) +3. Enable **SSH terminal access** +4. For experiments: **1×H100** (~$3/hr). For leaderboard: **8×H100 SXM** (~$20/hr) + +## 2. SSH into the pod + +```bash +ssh root@ -p +``` + +## 3. Run setup on the pod + +```bash +curl -fsSL https://raw.githubusercontent.com/openai/parameter-golf/main/... \ + | bash +# OR after syncing this repo: +bash runpod_setup.sh +``` + +## 4. Sync your local changes to the pod + +Set your pod connection info: +```bash +export RUNPOD_HOST="root@213.34.xx.xx" +export RUNPOD_PORT="22204" +``` + +One-shot sync: +```bash +bash sync_to_runpod.sh +``` + +Watch mode (auto-sync on file save): +```bash +bash sync_to_runpod.sh --watch +``` + +## 5. 
Training commands + +### Quick test (1×H100, sp1024 baseline, ~10 min) +```bash +cd /workspace/parameter-golf +RUN_ID=baseline_sp1024 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py +``` + +### Reproduce current SOTA (8×H100 SXM, sp8192) +```bash +cd /workspace/parameter-golf/records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT +RUN_ID=sota_repro \ +DATA_PATH=../../../data/datasets/fineweb10B_sp8192/ \ +TOKENIZER_PATH=../../../data/tokenizers/fineweb_8192_spm.model \ +VOCAB_SIZE=8192 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +### Unlimited compute (no 10-min cap) +```bash +MAX_WALLCLOCK_SECONDS=0 \ +VAL_LOSS_EVERY=200 \ +RUN_ID=my_long_run \ +... torchrun ... +``` + +## 6. Environment variables reference + +| Variable | Default | Description | +|---|---|---| +| `RUN_ID` | required | Name for this run's output | +| `DATA_PATH` | required | Path to FineWeb dataset shards | +| `TOKENIZER_PATH` | required | Path to .model tokenizer | +| `VOCAB_SIZE` | `1024` | Must match tokenizer (1024 or 8192) | +| `MAX_WALLCLOCK_SECONDS` | `600` | Set to `0` for unlimited | +| `VAL_LOSS_EVERY` | `0` | Print val loss every N steps | +| `VAL_BATCH_SIZE` | `8192` | Tokens per val batch | +| `ITERATIONS` | auto | Override training step count | diff --git a/experiments/train_gpt_deq.py b/experiments/train_gpt_deq.py new file mode 100644 index 0000000000..ce092007fe --- /dev/null +++ b/experiments/train_gpt_deq.py @@ -0,0 +1,899 @@ +""" +Experiment: Deep Equilibrium Universal Transformer (DEQ-UT) +============================================================ +"Universal transformer" is explicitly on OpenAI's wish list for parameter golf. 
+ +CONCEPT: + Instead of N sequential transformer layers, we run ONE transformer block + repeatedly until its hidden states converge to a fixed point: + + x* = f(x*, z) where z = input embedding, f = transformer block + + The model has 1 physical layer but effectively infinite depth at convergence. + Parameter count is ~1/11th of a normal model, spending all 16MB on fidelity. + +ARCHITECTURE: + - Single transformer Block (attn + MLP + norms) — ~2M unquantized params + - Anderson acceleration: 5-history window for fast convergence + - Phantom gradients for training: treat as if we ran K=4 steps, backprop through those + - At eval: run until ||z_{t+1} - z_t||/||z_t|| < tol (up to max_iter=20) + - Track per-step convergence stats for analysis + +WHY IT MIGHT WORK: + - Scaling law says more compute at inference beats more params at fixed budget + - DEQ = infinite depth from 1 physical block → extreme L(N) efficiency + - Compatible with all quantization/compression tricks (only 1 block to store) + - Test-Time Training becomes very cheap (only 1 block to adapt) + +KNOWN RISKS: + - Fixed-point may not converge for all inputs (use fallback after max_iter) + - Training instability: phantom gradient approximation can diverge + - Slower per-step than standard transformer (root-finding overhead) + - May need careful initialization (start with identity-ish residual connections) + +TO RUN (1xH100, ablation mode): + RUN_ID=deq_smoke \ + DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ + TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ + VOCAB_SIZE=1024 \ + MAX_WALLCLOCK_SECONDS=0 \ + ITERATIONS=2000 \ + DEQ_MAX_ITER=8 \ + DEQ_PHANTOM_STEPS=4 \ + torchrun --standalone --nproc_per_node=1 experiments/train_gpt_deq.py +""" + +import collections, copy, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F 
+from torch.nn.parallel import DistributedDataParallel as DDP +from torch import nn +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + HAS_FLASH3 = True +except ImportError: + HAS_FLASH3 = False + + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +class Hyperparameters: + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get('RUN_ID', str(uuid.uuid4())) + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.72)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 786432)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 200)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600)) + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 524288)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 500)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + # Model + vocab_size = int(os.environ.get('VOCAB_SIZE', 8192)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + rope_base = float(os.environ.get('ROPE_BASE', 1e4)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 5.5)) + # DEQ-specific + deq_max_iter_train = int(os.environ.get('DEQ_MAX_ITER_TRAIN', 8)) + deq_max_iter_eval = int(os.environ.get('DEQ_MAX_ITER_EVAL', 20)) + deq_phantom_steps = 
int(os.environ.get('DEQ_PHANTOM_STEPS', 4)) # steps to unroll for backprop + deq_tol = float(os.environ.get('DEQ_TOL', 1e-3)) # convergence threshold + deq_anderson_history = int(os.environ.get('DEQ_ANDERSON_HISTORY', 5)) + deq_anderson_beta = float(os.environ.get('DEQ_ANDERSON_BETA', 1.0)) + # Number of "warm" pre-iteration passes before starting Anderson (helps stability) + deq_warmup_iters = int(os.environ.get('DEQ_WARMUP_ITERS', 2)) + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.022)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1'))) + muon_wd = float(os.environ.get('MUON_WD', 0.095)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.9965)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + # Quantization + compressor = os.environ.get('COMPRESSOR', 'brotli') + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0)) + matrix_bits = int(os.environ.get('MATRIX_BITS', 6)) + embed_bits = int(os.environ.get('EMBED_BITS', 8)) + matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85)) + embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0)) + # Distributed + distributed = 'RANK' in os.environ and 'WORLD_SIZE' in 
os.environ + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # Derived + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + logfile = f'logs/{run_id}.txt' + model_path = 'final_model.pt' + quantized_model_path = 'final_model.int6.ptz' + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, 'a', encoding='utf-8') as f: + print(msg, file=f) + + +# --------------------------------------------------------------------------- +# Data loading (identical to main submission) +# --------------------------------------------------------------------------- +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError(f"VOCAB_SIZE mismatch") + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = \ + build_sentencepiece_luts(self.sp, h.vocab_size, device) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert sp.piece_to_id('▁') != sp.unk_id() + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + 
is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith('▁'): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode('utf-8')) + return (torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device)) + + +_SHARD_HEADER_BYTES = 256 * np.dtype(' 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor(np.array(mm[start_ind:start_ind + self.seq_len + 1], dtype=np.int64)) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to(self.device, 
non_blocking=True) + + +# --------------------------------------------------------------------------- +# Model building blocks +# --------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=2048, rope_dims=0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims) + self.register_buffer('inv_freq', inv_freq, persistent=False) + self._cache = {} + + def forward(self, seq_len, device, dtype): + key = (seq_len, device, dtype) + if key not in self._cache: + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cache[key] = (freqs.cos()[None, :, None, :].to(dtype), + freqs.sin()[None, :, None, :].to(dtype)) + return self._cache[key] + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + return torch.cat((torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1), x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], 
x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, rope_dims, qk_gain_init, train_seq_len): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, rope_dims=rope_dims) + + def forward(self, x): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if HAS_FLASH3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + # Fallback to sdpa + q = q.transpose(1, 2) + k = k.transpose(1, 2).expand(-1, self.num_heads, -1, -1) + v = v.transpose(1, 2).expand(-1, self.num_heads, -1, -1) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True).transpose(1, 2) + return self.proj(y.reshape(bsz, seqlen, dim)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + 
self.proj._zero_init = True + + def forward(self, x): + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class UniversalBlock(nn.Module): + """Single transformer block used for all iterations of the DEQ loop.""" + def __init__(self, h): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(h.model_dim, h.num_heads, h.num_kv_heads, + h.rope_base, h.rope_dims, h.qk_gain_init, h.train_seq_len) + self.mlp = MLP(h.model_dim, h.mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(h.model_dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(h.model_dim, dtype=torch.float32)) + # Learnable mixing with input embedding (DEQ injection) + self.input_gate = nn.Parameter(torch.zeros(h.model_dim, dtype=torch.float32)) + + def forward(self, z, z0): + """ + z: current fixed-point iterate [B, T, D] + z0: input embedding (injected at every step to maintain conditioning) + """ + # Condition on input at every iteration via learned gating + g = torch.sigmoid(self.input_gate.to(dtype=z.dtype))[None, None, :] + x = z + g * z0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +# --------------------------------------------------------------------------- +# Anderson Acceleration for fixed-point finding +# --------------------------------------------------------------------------- +def anderson_step(f_history, x_history, beta=1.0): + """ + Given history of function values f(x_i) and iterates x_i, + compute the Anderson mixing step to accelerate convergence. + + Returns the next iterate. 
+ """ + m = len(f_history) + if m == 1: + # No history to mix; just return f(x) + return f_history[0] + + # F matrix: columns are residuals f(x_i) - x_i + F = torch.stack([f - x for f, x in zip(f_history, x_history)], dim=-1) # [B*T*D, m] + # Solve least squares: min ||F @ alpha||^2 s.t. sum(alpha) = 1 + B, T, D = f_history[0].shape + F_flat = F.reshape(B * T * D, m) + # Normal equations: (F^T F) alpha = 1 / (1^T (F^T F)^-1 1) * (F^T F)^-1 1 + try: + FtF = F_flat.T @ F_flat + 1e-8 * torch.eye(m, device=F_flat.device, dtype=F_flat.dtype) + ones = torch.ones(m, 1, device=F_flat.device, dtype=F_flat.dtype) + alpha = torch.linalg.solve(FtF, ones) + alpha = alpha / alpha.sum() + except Exception: + # Fallback to pure iteration if solve fails + return f_history[-1] + # Mix: beta * sum(alpha_i * f(x_i)) + (1-beta) * sum(alpha_i * x_i) + x_stack = torch.stack(x_history, dim=-1) # [B, T, D, m] + f_stack = torch.stack(f_history, dim=-1) + alpha_t = alpha.reshape(1, 1, 1, m) + x_mix = (f_stack * alpha_t).sum(dim=-1) + return beta * x_mix + (1 - beta) * (x_stack * alpha_t).sum(dim=-1) + + +# --------------------------------------------------------------------------- +# DEQ GPT Model +# --------------------------------------------------------------------------- +class DEQGPT(nn.Module): + """ + Deep Equilibrium Universal Transformer. + One physical block run until fixed-point convergence. + + Virtual layer count at inference: as deep as needed for convergence. + Total parameter count: ~2M (1 block) + embeddings. 
+ """ + def __init__(self, h): + super().__init__() + self.h = h + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.block = UniversalBlock(h) + self.final_norm = RMSNorm() + # Lightweight encoder/decoder projections (4 layers each) to give the + # DEQ loop a better inductive starting point and output structure + self.pre_deq = nn.ModuleList([ + nn.Sequential( + RMSNorm(), + CastedLinear(h.model_dim, h.model_dim, bias=False) + ) for _ in range(2) + ]) + self.post_deq = nn.ModuleList([ + nn.Sequential( + RMSNorm(), + CastedLinear(h.model_dim, h.model_dim, bias=False) + ) for _ in range(2) + ]) + self.max_iter_train = h.deq_max_iter_train + self.max_iter_eval = h.deq_max_iter_eval + self.phantom_steps = h.deq_phantom_steps + self.tol = h.deq_tol + self.anderson_history = h.deq_anderson_history + self.anderson_beta = h.deq_anderson_beta + self.warmup_iters = h.deq_warmup_iters + # Track convergence stats + self.register_buffer('_iter_count', torch.zeros(1), persistent=False) + self.register_buffer('_iter_total', torch.zeros(1), persistent=False) + self._init_weights() + + def _init_weights(self): + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.005) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, '_zero_init', False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64: + nn.init.orthogonal_(module.weight, gain=0.5) + + def _run_block(self, z, z0): + """Run one iteration of the universal block.""" + return self.block(z, z0) + + @torch.no_grad() + def _fixed_point_eval(self, z0): + """ + Anderson-accelerated fixed-point finding at eval time. + Runs until convergence or max_iter_eval. 
+ """ + z = z0.clone() + f_history = [] + x_history = [] + iters_run = 0 + for i in range(self.max_iter_eval): + z_new = self._run_block(z, z0) + f_history.append(z_new) + x_history.append(z) + if len(f_history) > self.anderson_history: + f_history.pop(0) + x_history.pop(0) + if len(f_history) > self.warmup_iters: + z_next = anderson_step(f_history, x_history, beta=self.anderson_beta) + else: + z_next = z_new + # Check convergence + rel_change = (z_next - z).norm() / (z.norm() + 1e-8) + z = z_next + iters_run = i + 1 + if rel_change < self.tol: + break + self._iter_count += iters_run + self._iter_total += 1.0 + return z + + def _phantom_grad(self, z0): + """ + Training forward: phantom gradient approach. + Run phantom_steps iterations with gradient tracking, as if that were + the full fixed-point. This approximates the implicit gradient. + """ + z = z0.detach().clone() + # Warm start: a few no-grad iterations to get close to fixed point + with torch.no_grad(): + for _ in range(max(0, self.max_iter_train - self.phantom_steps)): + z = self._run_block(z, z0) + # Final phantom_steps with gradient + for _ in range(self.phantom_steps): + z = self._run_block(z, z0) + return z + + def forward_logits(self, input_ids): + x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,)) + # Pre-DEQ feature extraction + z0 = x + for layer in self.pre_deq: + norm_out, proj = layer + z0 = z0 + proj(norm_out(z0)) + # DEQ loop + if self.training: + z = self._phantom_grad(z0) + else: + z = self._fixed_point_eval(z0) + # Post-DEQ refinement + for layer in self.post_deq: + norm_out, proj = layer + z = z + proj(norm_out(z)) + z = self.final_norm(z) + logits = F.linear(z, self.tok_emb.weight) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + def forward(self, input_ids, target_ids): + logits = self.forward_logits(input_ids) + return F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), reduction='mean') + + def 
log_convergence_stats(self): + if self._iter_total > 0: + avg_iters = (self._iter_count / self._iter_total).item() + log(f"deq_convergence: avg_iters={avg_iters:.2f} over {int(self._iter_total.item())} calls") + self._iter_count.zero_() + self._iter_total.zero_() + + +# --------------------------------------------------------------------------- +# Optimizer — same MuonEq-R as main submission +# --------------------------------------------------------------------------- +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + a, b, c = 3.4445, -4.7750, 2.0315 + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, + weight_decay=0.0, row_normalize=False): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay, + row_normalize=row_normalize)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + is_dist = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if is_dist else 1 + rank = dist.get_rank() if is_dist else 0 + for group in self.param_groups: + params = group['params'] + if not params: + continue + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + buf.mul_(group['momentum']).add_(g) + if group['nesterov']: + g = g.add(buf, 
alpha=group['momentum']) + if group.get('row_normalize', False): + row_norms = g.float().norm(dim=-1, keepdim=True).clamp_min(1e-7) + g = g / row_norms.to(g.dtype) + g = zeropower_via_newtonschulz5(g, steps=group['backend_steps']) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if is_dist: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + wd = group.get('weight_decay', 0.0) + if wd > 0.0: + p.data.mul_(1.0 - group['lr'] * wd) + p.add_(updates_flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-group['lr']) + curr += p.numel() + return loss + + +CONTROL_PATTERNS = ('attn_scale', 'mlp_scale', 'q_gain', 'input_gate') + + +class DEQOptimizers: + def __init__(self, h, model): + named_params = list(model.named_parameters()) + matrix_params = [p for name, p in named_params + if p.ndim == 2 and not any(c in name for c in CONTROL_PATTERNS)] + scalar_params = [p for name, p in named_params + if p.ndim < 2 or any(c in name for c in CONTROL_PATTERNS)] + scalar_params = [p for p in scalar_params if p is not model.tok_emb.weight] + self.optimizer_tok = torch.optim.AdamW( + [{'params': [model.tok_emb.weight], 'lr': h.tied_embed_lr, 'base_lr': h.tied_embed_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.embed_wd, fused=True) + self.optimizer_muon = Muon(matrix_params, lr=h.matrix_lr, momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize) + for group in self.optimizer_muon.param_groups: + group['base_lr'] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{'params': scalar_params, 'lr': h.scalar_lr, 'base_lr': h.scalar_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.adam_wd, fused=True) + self.optimizers = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + 
for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + + +# --------------------------------------------------------------------------- +# Evaluation helpers +# --------------------------------------------------------------------------- +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val_sliding(h, device, val_data, base_model, batch_seqs=32): + base_model.eval() + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = total_windows * h.rank // h.world_size + my_e = total_windows * (h.rank + 1) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), 
reduction='none').reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + loss_sum += nll[i, s:wlen].to(torch.float64).sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + base_model.log_convergence_stats() + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") + return val_loss, val_bpb + + +# --------------------------------------------------------------------------- +# Training +# --------------------------------------------------------------------------- +def train_model(h, device, val_data): + base_model = DEQGPT(h).to(device).bfloat16() + # fp32 for control tensors + for name, param in base_model.named_parameters(): + if param.ndim < 2 or any(c in name for c in CONTROL_PATTERNS): + param.data = param.data.float() + compiled_model = torch.compile(base_model, dynamic=False) + if h.distributed: + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) + else: + model = compiled_model + total_params = sum(p.numel() for p in base_model.parameters()) + log(f"model_params: {total_params} ({total_params / 1e6:.2f}M)") + log(f"deq: max_iter_train={h.deq_max_iter_train} phantom_steps={h.deq_phantom_steps} " + 
f"max_iter_eval={h.deq_max_iter_eval} tol={h.deq_tol}") + optimizers = DEQOptimizers(h, base_model) + train_loader = ShuffledSequenceLoader(h, device) + max_wallclock_ms = 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last_step = (step == h.iterations or (stop_after_step is not None and step >= stop_after_step)) + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = timed_eval('val_sliding', eval_val_sliding, h, device, val_data, base_model) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + for opt in optimizers: + for 
group in opt.param_groups: + group['lr'] = group['base_lr'] * scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(h.ema_decay).add_(t.detach().float(), alpha=1.0 - h.ema_decay) + step += 1 + approx_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log = h.train_log_every > 0 and (step <= 5 or step % h.train_log_every == 0) + if should_log: + tok_per_sec = step * h.train_batch_tokens / (approx_ms / 1e3) + log(f"{step}/{h.iterations} train_loss:{train_loss.item():.4f} " + f"time:{approx_ms / 60000:.1f}m tok/s:{tok_per_sec:.0f}") + reached_cap = max_wallclock_ms is not None and approx_ms >= max_wallclock_ms + if stop_after_step is None and reached_cap: + stop_after_step = step + + log('ema: applying EMA weights') + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + return base_model + + +def main(): + world_size = int(os.environ.get('WORLD_SIZE', '1')) + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ + device = torch.device('cuda', local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend='nccl', device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision('high') + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs('logs', exist_ok=True) + log(f"=== DEQ Universal Transformer ===") + log(f"DEQ config: max_iter_train={h.deq_max_iter_train} phantom={h.deq_phantom_steps} " + f"max_iter_eval={h.deq_max_iter_eval} anderson_history={h.deq_anderson_history}") + random.seed(h.seed) + np.random.seed(h.seed) + 
torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + val_data = ValidationData(h, device) + train_model(h, device, val_data) + if distributed: + dist.destroy_process_group() + + +if __name__ == '__main__': + main() diff --git a/experiments/train_gpt_mod.py b/experiments/train_gpt_mod.py new file mode 100644 index 0000000000..ee6071621d --- /dev/null +++ b/experiments/train_gpt_mod.py @@ -0,0 +1,841 @@ +""" +Experiment: Mixture of Depths (MoD) +===================================== +"Mixture of Depths" — Raposo et al. 2024, explicitly on the parameter-golf wish list. + +CONCEPT: + At each transformer layer, a lightweight router (1 linear projection) decides + whether each token should pass through the full layer or be skipped (identity). + Only the top-k% of tokens by router score get processed; the rest pass through + unchanged. + +WHY IT WINS: + - Training is ~2× faster in FLOPs (skip 50% of tokens per layer) + - Same 10-minute wall clock → ~2× more gradient steps → better convergence + - The saved compute can also be used for: more layers, larger dim, or tighter loops + - Skipped tokens still attend to processed tokens in later layers (causal mask unchanged) + +KEY INSIGHT FOR PARAMETER GOLF: + MoD doesn't change the parameter count — it changes compute utilization. + In the 10-minute training window, more steps = better loss = lower BPB. + Even 1.5× training steps from 50% token routing is a huge advantage. 
+
+ARCHITECTURE:
+    - Exact same GPT as SOTA (11L × 512d, SP1024/8192, MuonEq-R, GPTQ)
+    - Each layer has a tiny router: Linear(dim, 1) → scalar score per token
+    - Top-k routing: only the top `router_capacity` fraction of tokens go through
+    - Skipped tokens: just get x = x (identity residual)
+    - Router is trained via an auxiliary load-balancing loss (forward pass uses hard top-k)
+
+ROUTER IMPLEMENTATION:
+    - Straight-through estimator for training: use soft scores for loss, hard mask for forward
+    - Capacity factor: 0.5 (50% of tokens processed) for max speedup
+    - Learnable to increase capacity on hard tokens, decrease on easy tokens
+    - Router loss: small auxiliary load-balancing loss to prevent router collapse
+
+ENV VARS:
+    MOD_CAPACITY        Router capacity (fraction of tokens to process) [default: 0.5]
+    MOD_LAYERS          Comma-separated layer indices to apply MoD to [default: "all" = layers 1..num_layers-2]
+    MOD_AUX_LOSS_COEF   Auxiliary load-balancing loss coefficient [default: 0.01]
+    MOD_ROUTER_INIT     Initial router bias [default: 0.0] (currently not read by the code)
+
+TO RUN:
+    RUN_ID=mod_smoke VOCAB_SIZE=1024 ITERATIONS=2000 MOD_CAPACITY=0.5 \
+    torchrun --standalone --nproc_per_node=1 experiments/train_gpt_mod.py
+"""
+
+import collections, copy, glob, io, math, os
+from pathlib import Path
+import random, sys, time, uuid
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch import nn
+
+try:
+    from flash_attn_interface import flash_attn_func as flash_attn_3_func
+    HAS_FLASH3 = True
+except ImportError:
+    HAS_FLASH3 = False
+
+
+# ---------------------------------------------------------------------------
+# Hyperparameters
+# ---------------------------------------------------------------------------
+class Hyperparameters:
+    data_dir = os.environ.get('DATA_DIR', './data/')
+    seed = int(os.environ.get('SEED', 1337))
+    run_id = os.environ.get('RUN_ID', str(uuid.uuid4()))
+    
iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.72)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 786432)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 200)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600)) + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 524288)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 500)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + # Model + vocab_size = int(os.environ.get('VOCAB_SIZE', 8192)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + rope_base = float(os.environ.get('ROPE_BASE', 1e4)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 5.5)) + num_loops = int(os.environ.get('NUM_LOOPS', 2)) + loop_start = int(os.environ.get('LOOP_START', 3)) + loop_end = int(os.environ.get('LOOP_END', 5)) + enable_looping_at = float(os.environ.get('ENABLE_LOOPING_AT', 0.35)) + parallel_residual_start = int(os.environ.get('PARALLEL_RESIDUAL_START', 7)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + # MoD-specific + mod_capacity = float(os.environ.get('MOD_CAPACITY', 0.5)) # fraction of tokens to process + mod_layers = os.environ.get('MOD_LAYERS', 'all') # 'all' or '1,3,5,7,9' + mod_aux_loss_coef = float(os.environ.get('MOD_AUX_LOSS_COEF', 0.01)) + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + tied_embed_lr = 
float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.022)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1'))) + muon_wd = float(os.environ.get('MUON_WD', 0.095)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.9965)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + # Quantization + compressor = os.environ.get('COMPRESSOR', 'brotli') + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0)) + matrix_bits = int(os.environ.get('MATRIX_BITS', 6)) + embed_bits = int(os.environ.get('EMBED_BITS', 8)) + matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85)) + embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0)) + # Distributed + distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # Derived + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', 
f'fineweb_{vocab_size}_bpe.model') + logfile = f'logs/{run_id}.txt' + + +_log_h = None +def set_log_h(h): + global _log_h + _log_h = h + +def log(msg): + if _log_h is None: + print(msg) + return + if _log_h.is_main_process: + print(msg) + if _log_h.logfile: + with open(_log_h.logfile, 'a') as f: + print(msg, file=f) + + +# --------------------------------------------------------------------------- +# Data loading (shared infrastructure) +# --------------------------------------------------------------------------- +_SHARD_HEADER_BYTES = 256 * np.dtype(' 0 else 0 + num_seq = (n - 1 - phase) // self.seq_len + self.start_inds[si] = (phase + self.rng.permutation(num_seq) * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + bsz = (global_tokens // grad_accum_steps) // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((bsz, self.seq_len), dtype=torch.int64) + y = torch.empty((bsz, self.seq_len), dtype=torch.int64) + for bi in range(bsz): + if remaining.sum() <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + si = int(self.rng.choice(len(self.files), p=remaining / remaining.sum())) + start = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + w = torch.as_tensor(np.array(mm[start:start + self.seq_len + 1], dtype=np.int64)) + x[bi], y[bi] = w[:-1], w[1:] + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# --------------------------------------------------------------------------- +# Model building blocks +# --------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + 
return F.linear(x, self.weight.to(x.dtype), + self.bias.to(x.dtype) if self.bias is not None else None) + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=2048, rope_dims=0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims) + self.register_buffer('inv_freq', inv_freq, persistent=False) + self._cache = {} + + def forward(self, seq_len, device, dtype): + key = (seq_len, device, dtype) + if key not in self._cache: + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cache[key] = (freqs.cos()[None, :, None, :].to(dtype), + freqs.sin()[None, :, None, :].to(dtype)) + return self._cache[key] + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + xr, xp = x[..., :rope_dims], x[..., rope_dims:] + h = rope_dims // 2 + x1, x2 = xr[..., :h], xr[..., h:] + return torch.cat((torch.cat((x1*cos + x2*sin, x1*-sin + x2*cos), dim=-1), xp), dim=-1) + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1*cos + x2*sin, x1*-sin + x2*cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, h): + super().__init__() + dim, nh, nkv = h.model_dim, h.num_heads, h.num_kv_heads + self.num_heads = nh + self.num_kv_heads = nkv + self.head_dim = dim // nh + kv_dim = nkv * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = 
CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((nh,), h.qk_gain_init, dtype=torch.float32)) + self.rope_dims = h.rope_dims + self.rotary = Rotary(self.head_dim, base=h.rope_base, + train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + + def forward(self, x): + B, T, D = x.shape + q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(B, T, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(T, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(q.dtype)[None, None, :, None] + if HAS_FLASH3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + q = q.transpose(1, 2) + k = k.transpose(1, 2).expand(-1, self.num_heads, -1, -1) + v = v.transpose(1, 2).expand(-1, self.num_heads, -1, -1) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True).transpose(1, 2) + return self.proj(y.reshape(B, T, D)) + + +class MLP(nn.Module): + def __init__(self, h): + super().__init__() + hidden = int(h.mlp_mult * h.model_dim) + self.fc = CastedLinear(h.model_dim, hidden, bias=False) + self.proj = CastedLinear(hidden, h.model_dim, bias=False) + self.proj._zero_init = True + + def forward(self, x): + return self.proj(F.leaky_relu(self.fc(x), 0.5).square()) + + +# --------------------------------------------------------------------------- +# MoD Token Router +# --------------------------------------------------------------------------- +class TokenRouter(nn.Module): + """ + Lightweight token router for Mixture of Depths. + Outputs a scalar score per token; top-k% tokens are routed through the block. 
+
+    During training: uses straight-through estimator
+      - Soft scores flow through for gradient
+      - Hard binary mask used for forward computation
+      - Auxiliary load-balancing loss prevents router collapse
+
+    During eval: deterministic top-k routing, no aux loss
+    """
+    def __init__(self, dim: int, capacity: float):
+        super().__init__()
+        self.capacity = capacity
+        # Tiny 1-layer router — just a single linear projection to a scalar
+        self.router = CastedLinear(dim, 1, bias=True)
+        # Initialize with small weights so early training doesn't over-route
+        nn.init.normal_(self.router.weight, std=0.01)
+        nn.init.zeros_(self.router.bias)
+
+    def forward(self, x: torch.Tensor):
+        """
+        x: [B, T, D]
+        Returns:
+            mask: [B, T] — 1.0 for tokens routed through the block, 0.0 for skipped (identity)
+            aux_loss: scalar auxiliary load-balancing loss
+        """
+        B, T, D = x.shape
+        k = max(1, int(T * self.capacity))
+
+        # Router scores: [B, T, 1] → [B, T]
+        scores = self.router(x).squeeze(-1) # [B, T]
+
+        # Top-k selection (hard mask for forward, soft scores for backward)
+        topk_vals, topk_idx = torch.topk(scores, k, dim=-1, sorted=False) # [B, k]
+
+        # Hard binary mask: 1 for routed tokens
+        mask = torch.zeros_like(scores, dtype=x.dtype) # [B, T]
+        mask.scatter_(1, topk_idx, 1.0)
+
+        # Straight-through: treat mask as if it were the soft scores during backward
+        # mask_ste = mask + (soft_scores - soft_scores.detach()) would conflate gradients
+        # Instead: keep the mask, but flow the router score signal through aux loss only
+        # NOTE(review): the hard mask carries no gradient, so the task loss never reaches
+        # the router here — it learns only from the auxiliary loss below.
+
+        # Auxiliary load-balancing loss:
+        # Encourages the router to select each token with equal probability across the batch
+        # (prevents all attention going to same k tokens always)
+        soft_probs = torch.sigmoid(scores) # [B, T]
+        # Target: capacity fraction should be True on average per token position
+        avg_selected = mask.float().mean(dim=0) # [T] - actual fraction selected per position
+        avg_prob = soft_probs.float().mean(dim=0) # [T] - expected fraction from router
+        
aux_loss = (avg_selected * avg_prob).mean() # correlation pushes toward uniform selection + + return mask, aux_loss + + +class MoDBlock(nn.Module): + """ + Transformer block with Mixture of Depths routing. + Only top-k% tokens pass through the attn+mlp; the rest get identity. + """ + def __init__(self, h, use_mod: bool = True, use_parallel_residual: bool = False): + super().__init__() + self.use_mod = use_mod + self.use_parallel_residual = use_parallel_residual + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(h) + self.mlp = MLP(h) + self.attn_scale = nn.Parameter(torch.ones(h.model_dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(h.model_dim, dtype=torch.float32)) + if use_mod: + self.router = TokenRouter(h.model_dim, capacity=h.mod_capacity) + else: + self.router = None + + def forward(self, x): + aux_loss = torch.zeros((), device=x.device, dtype=torch.float32) + + if self.router is not None: + mask, aux_loss = self.router(x) # mask: [B, T], float 0/1 + mask = mask.unsqueeze(-1) # [B, T, 1] + else: + mask = None + + if self.use_parallel_residual: + # GPT-J style: attn and MLP read same input + normed = self.attn_norm(x) + attn_out = self.attn(normed) * self.attn_scale.to(x.dtype) + mlp_out = self.mlp(self.mlp_norm(x)) * self.mlp_scale.to(x.dtype) + delta = attn_out + mlp_out + else: + attn_out = self.attn(self.attn_norm(x)) * self.attn_scale.to(x.dtype) + mlp_out = self.mlp(self.mlp_norm(x + attn_out)) * self.mlp_scale.to(x.dtype) + delta = attn_out + mlp_out + + if mask is not None: + # Apply delta only to routed tokens; identity for skipped tokens + x = x + mask * delta + else: + x = x + delta + + return x, aux_loss + + +# --------------------------------------------------------------------------- +# Full MoD GPT +# --------------------------------------------------------------------------- +class MoDGPT(nn.Module): + def __init__(self, h): + super().__init__() + self.h = h + self.logit_softcap = 
h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + + # Determine which layers get MoD routing + if h.mod_layers == 'all': + # Skip layer 0 and last layer — always process those + mod_layer_set = set(range(1, h.num_layers - 1)) + else: + mod_layer_set = set(int(x) for x in h.mod_layers.split(',')) + + self.blocks = nn.ModuleList([ + MoDBlock( + h, + use_mod=(i in mod_layer_set), + use_parallel_residual=(i >= h.parallel_residual_start) + ) + for i in range(h.num_layers) + ]) + # Skip gates (U-Net style, encoder→decoder) + if h.skip_gates_enabled: + n_skip = h.num_layers // 2 + self.skip_weights = nn.ParameterList([ + nn.Parameter(torch.ones(h.model_dim, dtype=torch.float32)) + for _ in range(n_skip) + ]) + self.final_norm = RMSNorm() + self.mod_aux_loss_coef = h.mod_aux_loss_coef + self.num_loops = h.num_loops + self.loop_start = h.loop_start + self.loop_end = h.loop_end + self.enable_looping_at = h.enable_looping_at + self._training_frac = 0.0 + self._init_weights() + + def _init_weights(self): + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.005) + for m in self.modules(): + if isinstance(m, nn.Linear): + if getattr(m, '_zero_init', False): + nn.init.zeros_(m.weight) + elif m.weight.ndim == 2 and m.weight.shape[0] >= 64: + nn.init.orthogonal_(m.weight, gain=0.5) + + def set_training_frac(self, frac: float): + self._training_frac = frac + + def _run_blocks(self, x, skip_connections=None): + total_aux = torch.zeros((), device=x.device, dtype=torch.float32) + n = len(self.blocks) + n_skip = n // 2 + # Looping + use_looping = self._training_frac >= self.enable_looping_at + num_extra = self.num_loops - 1 if use_looping else 0 + + layer_outputs = [] + li = 0 # physical layer index (with loop expansion) + expanded = (list(range(self.loop_start)) + + list(range(self.loop_start, self.loop_end + 1)) * (self.num_loops) + + list(range(self.loop_end + 1, n))) + + for block_idx in expanded: + x, aux = self.blocks[block_idx](x) + total_aux = 
total_aux + aux + layer_outputs.append((block_idx, x)) + + return x, total_aux + + def forward_logits(self, input_ids): + x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,)) + x, _ = self._run_blocks(x) + x = self.final_norm(x) + logits = F.linear(x, self.tok_emb.weight) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + def forward(self, input_ids, target_ids): + x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,)) + x, aux_loss = self._run_blocks(x) + x = self.final_norm(x) + logits = F.linear(x, self.tok_emb.weight) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1)) + return ce + self.mod_aux_loss_coef * aux_loss + + +# --------------------------------------------------------------------------- +# Optimizer (MuonEq-R + AdamW, same as main submission) +# --------------------------------------------------------------------------- +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + a, b, c = 3.4445, -4.7750, 2.0315 + X = G.bfloat16() + X /= X.norm() + eps + if X.size(0) > X.size(1): + X = X.T + for _ in range(steps): + A = X @ X.T + X = a * X + (b * A + c * A @ A) @ X + if G.size(0) > G.size(1): + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, + weight_decay=0.0, row_normalize=False): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay, + row_normalize=row_normalize)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + is_dist = dist.is_available() and dist.is_initialized() + ws = dist.get_world_size() if is_dist else 1 + rank = dist.get_rank() if is_dist else 0 + for group in self.param_groups: + params = 
group['params'] + if not params: + continue + total = sum(p.numel() for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % ws == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + buf.mul_(group['momentum']).add_(g) + if group['nesterov']: + g = g + buf * group['momentum'] + if group.get('row_normalize'): + g = g / g.float().norm(dim=-1, keepdim=True).clamp_min(1e-7).to(g.dtype) + g = zeropower_via_newtonschulz5(g, steps=group['backend_steps']) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if is_dist: + dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + if group.get('weight_decay', 0) > 0: + p.mul_(1.0 - group['lr'] * group['weight_decay']) + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-group['lr']) + curr += p.numel() + return loss + + +CTRL = ('attn_scale', 'mlp_scale', 'q_gain', 'skip_weight', 'router.router.bias') + + +def make_optimizers(h, model): + named = list(model.named_parameters()) + matrix_p = [p for n, p in named if p.ndim == 2 and not any(c in n for c in CTRL) + and p is not model.tok_emb.weight] + scalar_p = [p for n, p in named if p.ndim < 2 or any(c in n for c in CTRL)] + scalar_p = [p for p in scalar_p if p is not model.tok_emb.weight] + + opt_embed = torch.optim.AdamW( + [{'params': [model.tok_emb.weight], 'lr': h.tied_embed_lr, 'base_lr': h.tied_embed_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.embed_wd, fused=True) + opt_muon = Muon(matrix_p, lr=h.matrix_lr, momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize) + for g in opt_muon.param_groups: + g['base_lr'] = h.matrix_lr + opt_scalar = torch.optim.AdamW( + [{'params': 
scalar_p, 'lr': h.scalar_lr, 'base_lr': h.scalar_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.adam_wd, fused=True) + return [opt_embed, opt_muon, opt_scalar] + + +# --------------------------------------------------------------------------- +# Evaluation +# --------------------------------------------------------------------------- +def eval_val_sliding(h, device, val_data, model, batch_seqs=32): + model.eval() + sl = h.eval_seq_len + ctx = sl - h.eval_stride + total = val_data.val_tokens.numel() - 1 + starts = [ws for ws in range(0, total, h.eval_stride) if ws + ctx < total] + my_starts = starts[h.rank::h.world_size] + ls = torch.zeros((), device=device, dtype=torch.float64) + tc = torch.zeros((), device=device, dtype=torch.float64) + bc = torch.zeros((), device=device, dtype=torch.float64) + with torch.inference_mode(): + for bi in range(0, len(my_starts), batch_seqs): + ws_batch = my_starts[bi:bi + batch_seqs] + bsz = len(ws_batch) + xb = torch.zeros(bsz, sl, dtype=torch.int64, device=device) + yb = torch.zeros(bsz, sl, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(ws_batch): + we = min(ws + sl, total) + wlen = we - ws + wlens.append(wlen) + ch = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + xb[i, :wlen] = ch[:-1] + yb[i, :wlen] = ch[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + lg = model.forward_logits(xb) + nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)).float(), + yb.reshape(-1), reduction='none').reshape(bsz, sl) + for i, ws in enumerate(ws_batch): + wlen = wlens[i] + s = 0 if ws == 0 else ctx + ls += nll[i, s:wlen].to(torch.float64).sum() + tc += float(wlen - s) + tgt = yb[i, s:wlen] + prev = xb[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & + ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + bc += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, 
op=dist.ReduceOp.SUM) + dist.all_reduce(tc, op=dist.ReduceOp.SUM) + dist.all_reduce(bc, op=dist.ReduceOp.SUM) + val_loss = (ls / tc).item() + val_bpb = val_loss / math.log(2.0) * (tc.item() / bc.item()) + model.train() + return val_loss, val_bpb + + +# --------------------------------------------------------------------------- +# Training +# --------------------------------------------------------------------------- +def train_model(h, device, val_data): + base_model = MoDGPT(h).to(device).bfloat16() + for name, p in base_model.named_parameters(): + if p.ndim < 2 or any(c in name for c in CTRL): + p.data = p.data.float() + + n_mod = sum(1 for b in base_model.blocks if b.use_mod) + n_total = len(base_model.blocks) + total_params = sum(p.numel() for p in base_model.parameters()) + log(f"MoD: {n_mod}/{n_total} blocks have routing | capacity={h.mod_capacity:.0%}") + log(f"model_params: {total_params} ({total_params/1e6:.2f}M)") + log(f"expected_speedup: ~{1 / (1 - h.mod_capacity * n_mod / n_total) :.2f}x FLOPs saved") + + compiled = torch.compile(base_model, dynamic=False) + model = DDP(compiled, device_ids=[h.local_rank], broadcast_buffers=False) \ + if h.distributed else compiled + opts = make_optimizers(h, base_model) + + loader = ShuffledSequenceLoader(h, device) + max_ms = 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_ms: + max_ms -= h.gptq_reserve_seconds * 1e3 + + def lr_scale(frac): + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + ema = {n: t.detach().float().clone() for n, t in base_model.state_dict().items()} + training_ms = 0.0 + stop_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last = step == h.iterations or (stop_step is not None and step >= stop_step) + if last or (h.val_loss_every > 0 and step % h.val_loss_every == 0): + torch.cuda.synchronize() + training_ms += 1e3 * (time.perf_counter() - t0) + t_eval = 
def main():
    """Entry point: set up CUDA/DDP, seed all RNGs, then run training."""
    is_ddp = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
    rank_local = int(os.environ.get('LOCAL_RANK', '0'))
    device = torch.device('cuda', rank_local)
    torch.cuda.set_device(device)
    if is_ddp:
        dist.init_process_group(backend='nccl', device_id=device)
        dist.barrier()
    # TF32 matmuls: free throughput at the precision this run trains in.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
    h = Hyperparameters()
    set_log_h(h)
    if h.is_main_process:
        os.makedirs('logs', exist_ok=True)
    log(f"=== Mixture of Depths GPT ===")
    log(f"capacity={h.mod_capacity:.0%} | layers={h.mod_layers} | aux_coef={h.mod_aux_loss_coef}")
    # Seed every RNG source so reruns are reproducible.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(h.seed)
    val_data = ValidationData(h, device)
    train_model(h, device, val_data)
    if is_ddp:
        dist.destroy_process_group()


if __name__ == '__main__':
    main()
+ + W_effective = W_random(seed) + B @ A (full rank ≈ but LoRA adapts it) + +PARAMETER BUDGET (512-dim, 11 layers): + Full baseline matrices: 11 × (4×512×512 + 512×2048 + 2048×512) ≈ 24M params + LoRA A+B (rank 8 attn, rank 4 mlp): 11 × 4 × (512×8 + 8×512) + 11 × 2 × (512×4 + 4×512) ≈ 440K params + = ~98% parameter reduction on weight matrices + +STORAGE PLAN: + Matrices: 0 bytes (regenerated from seeds) + LoRA A,B: 440K × int6 ≈ 330KB + Embeddings: vocab × dim × int8 = 8192 × 512 × 1B = 4MB + Seeds list: 11 × 4 layers × 4 bytes = 176 bytes (negligible) + All control vectors (gains, norms): ~tiny + TOTAL: well under 16MB → can use higher-precision LoRA or larger rank! + +FAST RANDOM MATRIX: + Use PyTorch Generator-based structured randomness. + We use Kronecker / Hadamard-like structure for efficient matmul + (FastFood transform: W ≈ S·H·G·Π·H·B where H=Hadamard, others are diagonal/perm) + This makes W@x O(n log n) instead of O(n²). + +TRAINING: + 1. Generate random W from seed (on device); use @=no_grad + 2. Compute y = F.linear(x, W) + F.linear(F.linear(x, A), B) (LoRA addition) + 3. Backprop only through A,B + 4. 
At quantization/compression time: only A,B,embeddings, seed list are stored + +EVAL: + Standard sliding window BPB (same eval infrastructure as main submission) + +RISKS: + - Random bases may not capture useful inductive biases (untrained) + - Training signal may get diluted by large random component + - FastFood approx random vs full random: different inductive theory + - Attention QK with purely random weights may degrade attn quality + +MITIGATION: + - Add skip connections: y = x + scale * (W_random @ x + B @ A @ x) + so gradient path always has identity + - Use gradient checkpointing (recompute random W at backward pass) + - Start with smaller rank-2 and see if loss decreases at all + +TO RUN (1xH100, ablation mode): + RUN_ID=seeds_smoke \ + DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ + TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ + VOCAB_SIZE=1024 \ + MAX_WALLCLOCK_SECONDS=0 \ + ITERATIONS=1000 \ + LORA_RANK_ATTN=8 \ + LORA_RANK_MLP=4 \ + torchrun --standalone --nproc_per_node=1 experiments/train_gpt_seeds.py +""" + +import collections, copy, glob, io, math, os +from pathlib import Path +import random, sys, time, uuid +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import nn + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + HAS_FLASH3 = True +except ImportError: + HAS_FLASH3 = False + + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +class Hyperparameters: + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get('RUN_ID', str(uuid.uuid4())) + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.72)) + 
warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 786432)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 200)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600)) + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 524288)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 500)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + # Model + vocab_size = int(os.environ.get('VOCAB_SIZE', 8192)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + rope_base = float(os.environ.get('ROPE_BASE', 1e4)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 5.5)) + # Seed-LoRA specific + lora_rank_attn = int(os.environ.get('LORA_RANK_ATTN', 8)) + lora_rank_mlp = int(os.environ.get('LORA_RANK_MLP', 4)) + # Whether to use FastFood structured random (True) or dense random (False) + # Dense is more expressive but slower; FastFood scales O(n log n). 
+ use_fastfood = bool(int(os.environ.get('USE_FASTFOOD', '0'))) + # Scale factor for random basis output (keeps activations in range) + random_basis_scale = float(os.environ.get('RANDOM_BASIS_SCALE', 1.0)) + # Whether the random basis has its own scale parameter (learnable) + learn_random_scale = bool(int(os.environ.get('LEARN_RANDOM_SCALE', '1'))) + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + lora_lr = float(os.environ.get('LORA_LR', 0.01)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.9965)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + # Quantization + compressor = os.environ.get('COMPRESSOR', 'brotli') + lora_bits = int(os.environ.get('LORA_BITS', 6)) + embed_bits = int(os.environ.get('EMBED_BITS', 8)) + lora_clip_sigmas = float(os.environ.get('LORA_CLIP_SIGMAS', 12.85)) + embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0)) + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0)) + # Distributed + distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # Derived + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 
def generate_random_matrix(seed: int, out_features: int, in_features: int,
                           device, dtype) -> torch.Tensor:
    """Deterministically materialize a random Gaussian weight matrix.

    The matrix is a pure function of ``seed`` — it is never serialized and
    is regenerated identically on every call. Entries are scaled by
    ``1/sqrt(in_features)``, matching a standard linear-layer init.
    """
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    basis = torch.randn(out_features, in_features, dtype=dtype, device=device, generator=gen)
    return basis * (1.0 / math.sqrt(in_features))
+ """ + def __init__(self, dim: int, seed: int, device, dtype): + super().__init__() + assert (dim & (dim - 1)) == 0, f"FastFood requires power-of-2 dim, got {dim}" + self.dim = dim + g = torch.Generator() + g.manual_seed(seed) + d1 = (torch.randint(0, 2, (dim,), generator=g).float() * 2 - 1).to(device=device, dtype=dtype) + g2 = torch.Generator() + g2.manual_seed(seed + 1) + d2 = (torch.randint(0, 2, (dim,), generator=g2).float() * 2 - 1).to(device=device, dtype=dtype) + g3 = torch.Generator() + g3.manual_seed(seed + 2) + gauss_d = torch.randn(dim, generator=g3).abs().to(device=device, dtype=dtype) + g4 = torch.Generator() + g4.manual_seed(seed + 3) + perm = torch.randperm(dim, generator=g4).to(device=device) + self.register_buffer('d1', d1, persistent=False) + self.register_buffer('d2', d2, persistent=False) + self.register_buffer('gauss_d', gauss_d, persistent=False) + self.register_buffer('perm', perm, persistent=False) + # Normalization: E[||y||²] = ||x||² for this construction + self.scale = 1.0 / math.sqrt(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """x: [..., dim] → [..., dim]""" + x = x * self.d1 + x = self._fwht(x) + x = x[:, self.perm] if x.ndim == 2 else x[..., self.perm] + x = x * self.gauss_d + x = self._fwht(x) + x = x * self.d2 + return x * self.scale + + @staticmethod + def _fwht(x: torch.Tensor) -> torch.Tensor: + """Fast Walsh-Hadamard Transform (iterative, in-place).""" + n = x.size(-1) + h = 1 + while h < n: + x = x.reshape(*x.shape[:-1], n // (2 * h), 2 * h) + a = x[..., :h] + b = x[..., h:] + x = torch.cat([a + b, a - b], dim=-1) + x = x.reshape(*x.shape[:-2], n) + h *= 2 + return x + + +# --------------------------------------------------------------------------- +# Seeded LoRA Linear Layer +# --------------------------------------------------------------------------- +class SeededLoRALinear(nn.Module): + """ + Linear layer with: + - A random weight matrix W_random(seed) — never stored, regenerated at runtime + - A small 
class SeededLoRALinear(nn.Module):
    """
    Linear layer whose dense weight is a seeded random matrix plus a LoRA delta.

    The random basis ``W_random(seed)`` is regenerated on every forward pass
    and is never serialized; only ``lora_A``, ``lora_B`` and the integer seed
    survive to disk. Effectively:
    ``y = rand_scale * (W_random(seed) @ x) + lora_B @ lora_A @ x``.
    """

    def __init__(self, in_features: int, out_features: int, seed: int, rank: int,
                 use_fastfood: bool = False, learn_random_scale: bool = True):
        super().__init__()
        assert in_features == out_features or not use_fastfood, \
            "FastFood requires in_features == out_features"
        self.in_features = in_features
        self.out_features = out_features
        self.seed = seed
        self.rank = rank
        self.use_fastfood = use_fastfood
        # Trainable low-rank adapters. B starts at zero so the layer initially
        # equals the pure (scaled) random basis — standard LoRA init.
        self.lora_A = nn.Parameter(torch.empty(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        # Optional learnable gain on the frozen random component.
        self.rand_scale = nn.Parameter(torch.ones(1)) if learn_random_scale else None
        nn.init.normal_(self.lora_A, std=0.02 / math.sqrt(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The basis itself never receives gradients, but x still does, so the
        # activation gradient path through this layer stays intact.
        with torch.no_grad():
            W_rand = generate_random_matrix(self.seed, self.out_features,
                                            self.in_features,
                                            device=x.device, dtype=x.dtype)
        rand_out = F.linear(x, W_rand)
        if self.rand_scale is not None:
            rand_out = rand_out * self.rand_scale.to(dtype=x.dtype)
        # Low-rank trainable correction.
        delta = F.linear(F.linear(x, self.lora_A.to(dtype=x.dtype)),
                         self.lora_B.to(dtype=x.dtype))
        return rand_out + delta

    def parameter_count(self):
        return sum(p.numel() for p in self.parameters())

    def extra_repr(self):
        return (f'in={self.in_features}, out={self.out_features}, '
                f'rank={self.rank}, seed={self.seed}, fastfood={self.use_fastfood}')
class Rotary(nn.Module):
    """Rotary position embeddings with NTK-style base rescaling.

    Caches the (cos, sin) tables per (seq_len, device, dtype). For sequences
    longer than ``train_seq_len`` the rotary base is enlarged so the lowest
    frequency still spans the longer sequence (NTK-aware extrapolation).
    """

    def __init__(self, dim, base=1e4, train_seq_len=2048, rope_dims=0):
        super().__init__()
        # rope_dims == 0 means "rotate the full head dimension".
        self.rope_dims = dim if rope_dims <= 0 else rope_dims
        self.base = base
        self.train_seq_len = train_seq_len
        exponents = torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims
        self.register_buffer('inv_freq', 1.0 / base ** exponents, persistent=False)
        self._cache = {}

    def forward(self, seq_len, device, dtype):
        key = (seq_len, device, dtype)
        cached = self._cache.get(key)
        if cached is not None:
            return cached
        rd = self.rope_dims
        if seq_len > self.train_seq_len:
            # Stretch the base so frequencies stay in-distribution at longer lengths.
            scale = seq_len / self.train_seq_len
            stretched = self.base * scale ** (rd / (rd - 2))
            inv_freq = 1.0 / stretched ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)
        else:
            inv_freq = self.inv_freq.to(device)
        positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        angles = torch.outer(positions, inv_freq)
        # Shaped [1, seq, 1, rd/2] so it broadcasts over (batch, heads).
        entry = (angles.cos()[None, :, None, :].to(dtype),
                 angles.sin()[None, :, None, :].to(dtype))
        self._cache[key] = entry
        return entry
class SeededSelfAttention(nn.Module):
    """Causal self-attention whose projections are seeded-random + LoRA.

    Q/K/V/output projections are ``SeededLoRALinear`` layers: the dense basis
    is regenerated from per-layer seeds and only the LoRA adapters train.
    Uses grouped-query attention (``num_kv_heads`` K/V heads shared across
    ``num_heads`` query heads), QK-RMSNorm, partial RoPE, and a learnable
    per-head query gain.
    """

    def __init__(self, h, layer_idx: int, seed_offset: int = 0):
        super().__init__()
        dim = h.model_dim
        num_heads = h.num_heads
        num_kv_heads = h.num_kv_heads
        head_dim = dim // num_heads
        kv_dim = num_kv_heads * head_dim
        rank = h.lora_rank_attn
        # Distinct deterministic seed per layer; the +0..+3 below give each
        # projection its own random basis.
        base_seed = (layer_idx * 7 + seed_offset) * 1000 + 42
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.c_q = SeededLoRALinear(dim, dim, base_seed + 0, rank, use_fastfood=False,
                                    learn_random_scale=h.learn_random_scale)
        self.c_k = SeededLoRALinear(dim, kv_dim, base_seed + 1, rank, use_fastfood=False,
                                    learn_random_scale=h.learn_random_scale)
        self.c_v = SeededLoRALinear(dim, kv_dim, base_seed + 2, rank, use_fastfood=False,
                                    learn_random_scale=h.learn_random_scale)
        self.proj = SeededLoRALinear(dim, dim, base_seed + 3, rank, use_fastfood=False,
                                    learn_random_scale=h.learn_random_scale)
        # Learnable per-head query gain (kept in fp32 by the training setup).
        self.q_gain = nn.Parameter(torch.full((num_heads,), h.qk_gain_init, dtype=torch.float32))
        self.rope_dims = h.rope_dims
        self.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len,
                             rope_dims=h.rope_dims)

    def forward(self, x):
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        # QK-RMSNorm stabilizes attention logits before RoPE/gain are applied.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
        if HAS_FLASH3:
            # flash-attn-3 consumes (b, s, h, d) directly and handles GQA.
            y = flash_attn_3_func(q, k, v, causal=True)
        else:
            # SDPA fallback. BUG FIX: the previous code used
            # .expand(-1, num_heads, -1, -1) on the KV-head dimension, which
            # raises a RuntimeError whenever num_kv_heads != num_heads
            # (expand can only broadcast size-1 dims). Materialize each KV
            # group explicitly instead.
            groups = self.num_heads // self.num_kv_heads
            q = q.transpose(1, 2)
            k = k.transpose(1, 2).repeat_interleave(groups, dim=1)
            v = v.transpose(1, 2).repeat_interleave(groups, dim=1)
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True).transpose(1, 2)
        return self.proj(y.reshape(bsz, seqlen, dim))
class SeededGPT(nn.Module):
    """
    GPT whose weight matrices are seeded random bases + LoRA adapters.

    Only the LoRA adapters (plus their random-basis gains) and the tied token
    embedding carry serialized parameters; every dense projection is
    regenerated from an integer seed at runtime.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.logit_softcap = h.logit_softcap
        self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim)
        self.blocks = nn.ModuleList(SeededBlock(h, i) for i in range(h.num_layers))
        self.final_norm = RMSNorm()
        nn.init.normal_(self.tok_emb.weight, mean=0.0, std=h.tied_embed_init_std)

    def _count_lora_params(self):
        # Adapters + per-layer random-basis gains: the stored trainables.
        return sum(p.numel() for name, p in self.named_parameters()
                   if 'lora_' in name or 'rand_scale' in name)

    def forward_logits(self, input_ids):
        # Normalized embedding → blocks → final norm → tied, soft-capped logits.
        x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,))
        for block in self.blocks:
            x = block(x)
        x = self.final_norm(x)
        logits = F.linear(x, self.tok_emb.weight)
        # tanh soft-cap keeps |logit| <= logit_softcap.
        return self.logit_softcap * torch.tanh(logits / self.logit_softcap)

    def forward(self, input_ids, target_ids):
        logits = self.forward_logits(input_ids)
        flat = logits.reshape(-1, logits.size(-1)).float()
        return F.cross_entropy(flat, target_ids.reshape(-1))
def eval_val_sliding(h, device, val_data, model, batch_seqs=32):
    """Sliding-window validation; returns (mean token NLL, bits-per-byte).

    Windows of ``eval_seq_len`` tokens start every ``eval_stride`` tokens.
    Only the last ``eval_stride`` positions of each window are scored — the
    leading ``context`` positions condition the prediction — except the very
    first window, which is scored from position 0. Window starts are sharded
    round-robin across ranks and the accumulators all-reduced at the end.
    """
    model.eval()
    seq_len = h.eval_seq_len
    context = seq_len - h.eval_stride  # positions used purely as conditioning
    total = val_data.val_tokens.numel() - 1  # last token has no target
    starts = [ws for ws in range(0, total, h.eval_stride) if ws + context < total]
    my_starts = starts[h.rank::h.world_size]  # this rank's shard of windows
    # float64 accumulators: summed NLL, scored-token count, byte count.
    ls = torch.zeros((), device=device, dtype=torch.float64)
    tc = torch.zeros((), device=device, dtype=torch.float64)
    bc = torch.zeros((), device=device, dtype=torch.float64)
    with torch.inference_mode():
        for bi in range(0, len(my_starts), batch_seqs):
            ws_batch = my_starts[bi:bi + batch_seqs]
            bsz = len(ws_batch)
            x_b = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_b = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens = []
            for i, ws in enumerate(ws_batch):
                we = min(ws + seq_len, total)  # final window may be short
                wlen = we - ws
                wlens.append(wlen)
                ch = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device)
                x_b[i, :wlen] = ch[:-1]
                y_b[i, :wlen] = ch[1:]
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = model.forward_logits(x_b)
                nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(),
                                      y_b.reshape(-1), reduction='none').reshape(bsz, seq_len)
            for i, ws in enumerate(ws_batch):
                wlen = wlens[i]
                # Score everything in the first window, else only the stride tail.
                s = 0 if ws == 0 else context
                ls += nll[i, s:wlen].to(torch.float64).sum()
                tc += float(wlen - s)
                tgt = y_b[i, s:wlen]
                prev = x_b[i, s:wlen]
                # Per-token byte counts for the bpb denominator.
                # NOTE(review): the LUTs presumably come from the tokenizer via
                # ValidationData — the extra byte looks like a leading space
                # that only counts when the previous token is not a boundary
                # token; confirm against ValidationData's construction.
                tb = val_data.base_bytes_lut[tgt].to(torch.float64)
                tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64)
                bc += tb.sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(ls, op=dist.ReduceOp.SUM)
        dist.all_reduce(tc, op=dist.ReduceOp.SUM)
        dist.all_reduce(bc, op=dist.ReduceOp.SUM)
    val_loss = (ls / tc).item()
    # nats/token → bits/token, then tokens → bytes via the tc/bc ratio.
    val_bpb = val_loss / math.log(2.0) * (tc.item() / bc.item())
    model.train()
    return val_loss, val_bpb
def train_model(h, device, val_data):
    """Train a SeededGPT under a wallclock budget and return the EMA weights.

    Runs until either ``h.iterations`` steps complete or the wallclock budget
    (minus the time reserved for quantization) is exhausted, evaluating
    periodically with the training clock paused, and finally loads the
    EMA-averaged state back into the model.
    """
    base_model = SeededGPT(h).to(device).bfloat16()
    # Keep 1-D tensors and the control gains/scales in fp32 for stable updates;
    # the matrices stay bf16.
    for name, p in base_model.named_parameters():
        if p.ndim < 2 or 'q_gain' in name or 'attn_scale' in name or 'mlp_scale' in name:
            p.data = p.data.float()
    compiled = torch.compile(base_model, dynamic=False)
    if h.distributed:
        model = DDP(compiled, device_ids=[h.local_rank], broadcast_buffers=False)
    else:
        model = compiled
    total_params = sum(p.numel() for p in base_model.parameters())
    lora_params = base_model._count_lora_params()
    emb_params = base_model.tok_emb.weight.numel()
    log(f"total_params={total_params} lora_params={lora_params} emb_params={emb_params}")
    log(f"param_budget: LoRA={lora_params} embed={emb_params} "
        f"(random_basis: {total_params - lora_params - emb_params} params never stored)")

    optimizers = make_optimizers(h, base_model)
    train_loader = ShuffledSequenceLoader(h, device)
    # Wallclock budget in ms, minus the tail reserved for GPTQ quantization.
    max_ms = 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None
    if max_ms is not None:
        max_ms -= h.gptq_reserve_seconds * 1e3

    def lr_scale(elapsed_ms, step):
        # Constant LR, then a linear warmdown over the final `warmdown_frac`
        # of the run; progress is time-based when a wallclock budget is set.
        frac = elapsed_ms / max_ms if max_ms else step / h.iterations
        if frac >= 1.0 - h.warmdown_frac:
            return max((1.0 - frac) / h.warmdown_frac, h.min_lr)
        return 1.0

    # fp32 EMA shadow of every state-dict entry; loaded back after training.
    ema = {n: t.detach().float().clone() for n, t in base_model.state_dict().items()}
    training_ms = 0.0  # accumulated training time, excluding eval pauses
    stop_step = None   # set once the budget is hit; next loop evals and exits
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    step = 0

    while True:
        last = step == h.iterations or (stop_step is not None and step >= stop_step)
        if last or (h.val_loss_every > 0 and step % h.val_loss_every == 0):
            # Pause the training clock around eval.
            torch.cuda.synchronize()
            training_ms += 1e3 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val_sliding(h, device, val_data, base_model)
            log(f"step:{step} val_loss:{val_loss:.6f} val_bpb:{val_bpb:.6f}")
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            if last:
                break
        elapsed = training_ms + 1e3 * (time.perf_counter() - t0)
        scale = lr_scale(elapsed, step)
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)
        train_loss = torch.zeros((), device=device)
        for _ in range(h.grad_accum_steps):
            x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps)
            if h.distributed:
                # Only sync gradients across ranks on the last micro-batch.
                model.require_backward_grad_sync = (_ == h.grad_accum_steps - 1)
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss / h.grad_accum_steps).backward()
        train_loss /= h.grad_accum_steps
        for opt in optimizers:
            for group in opt.param_groups:
                group['lr'] = group['base_lr'] * scale
        if h.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        with torch.no_grad():
            # Fold the fresh weights into the EMA shadow.
            for name, t in base_model.state_dict().items():
                ema[name].mul_(h.ema_decay).add_(t.detach().float(), alpha=1.0 - h.ema_decay)
        step += 1
        approx_ms = training_ms + 1e3 * (time.perf_counter() - t0)
        if h.train_log_every > 0 and (step <= 5 or step % h.train_log_every == 0):
            tok_per_sec = step * h.train_batch_tokens / (approx_ms / 1e3)
            log(f"{step}/{h.iterations} train_loss:{train_loss.item():.4f} "
                f"time:{approx_ms / 60000:.1f}m tok/s:{tok_per_sec:.0f}")
        if stop_step is None and max_ms is not None and approx_ms >= max_ms:
            stop_step = step

    # Swap in the EMA weights, cast back to each entry's live dtype.
    avg_state = {n: t.to(dtype=base_model.state_dict()[n].dtype) for n, t in ema.items()}
    base_model.load_state_dict(avg_state, strict=True)
    return base_model
def main():
    """Entry point: set up CUDA/DDP, seed all RNGs, then run training."""
    is_ddp = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
    rank_local = int(os.environ.get('LOCAL_RANK', '0'))
    device = torch.device('cuda', rank_local)
    torch.cuda.set_device(device)
    if is_ddp:
        dist.init_process_group(backend='nccl', device_id=device)
        dist.barrier()
    # TF32 matmuls: free throughput at the precision this run trains in.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
    h = Hyperparameters()
    set_log_h(h)
    if h.is_main_process:
        os.makedirs('logs', exist_ok=True)
    log(f"=== Seeded Random Basis + LoRA ===")
    log(f"lora_rank: attn={h.lora_rank_attn} mlp={h.lora_rank_mlp}")
    log(f"use_fastfood={h.use_fastfood} learn_random_scale={h.learn_random_scale}")
    # Seed every RNG source so reruns are reproducible.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(h.seed)
    val_data = ValidationData(h, device)
    train_model(h, device, val_data)
    if is_ddp:
        dist.destroy_process_group()


if __name__ == '__main__':
    main()
+# Designed for 1×H100 ablation mode (cheap, ~$3.50/hr) +# +# Usage: +# bash gauntlet.sh [--steps N] [--vocab V] [--gpus N] +# +# Options: +# --steps N Training iterations per experiment (default: 2000) +# --vocab V Vocabulary size: 1024 or 8192 (default: 1024) +# --gpus N Number of GPUs per run (default: 1) +# --deq-only Only run DEQ experiment +# --seeds-only Only run Seed-LoRA experiment +# --incr-only Only run incremental submission +# --skip-deq Skip DEQ (it's slower per step) +# +# Output: gauntlet_results.txt with timestamped BPB for every run +# ============================================================ +# NOTE: do not use set -e — experiments are allowed to fail individually + +STEPS=2000 +VOCAB=1024 +GPUS=1 +SKIP_DEQ=0 +SKIP_BASELINE=0 +INCR_ONLY=0 +DEQ_ONLY=0 +SEEDS_ONLY=0 +MOD_ONLY=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --steps) STEPS="$2"; shift 2 ;; + --vocab) VOCAB="$2"; shift 2 ;; + --gpus) GPUS="$2"; shift 2 ;; + --skip-deq) SKIP_DEQ=1; shift ;; + --skip-baseline) SKIP_BASELINE=1; shift ;; + --incr-only) INCR_ONLY=1; shift ;; + --deq-only) DEQ_ONLY=1; shift ;; + --seeds-only) SEEDS_ONLY=1; shift ;; + --mod-only) MOD_ONLY=1; shift ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac +done + +# ---- paths ---- +REPO_DIR="$(cd "$(dirname "$0")" && pwd)" +DATA_DIR="$REPO_DIR/data/datasets/fineweb10B_sp${VOCAB}" +TOK_DIR="$REPO_DIR/data/tokenizers" +TOKENIZER_PATH="$TOK_DIR/fineweb_${VOCAB}_bpe.model" +RESULTS="$REPO_DIR/gauntlet_results.txt" +LOG_DIR="$REPO_DIR/logs/gauntlet" +mkdir -p "$LOG_DIR" + +echo "========================================" +echo " GAUNTLET TEST — $(date)" +echo " steps=$STEPS vocab=sp${VOCAB} gpus=$GPUS" +echo "========================================" +echo "" + +# ---- check data ---- +if [ ! -d "$DATA_DIR" ]; then + echo "ERROR: Data not found at $DATA_DIR" + echo "Run: python3 data/cached_challenge_fineweb.py --variant sp${VOCAB} --train-shards 5" + echo "Or check: tail -f /tmp/data_download.log" + exit 1 +fi + +if [ ! 
-f "$TOKENIZER_PATH" ]; then
+  echo "ERROR: Tokenizer not found at $TOKENIZER_PATH"
+  exit 1
+fi
+
+echo "Data: $DATA_DIR ✓"
+echo "Tokenizer: $TOKENIZER_PATH ✓"
+echo ""
+
+# ---- helper ----
+# run_experiment NAME SCRIPT [EXTRA_ENV]
+#   Launches one torchrun job for SCRIPT, tees its output to a per-run log,
+#   then appends a timestamped result line (best/last val_bpb, elapsed) to
+#   $RESULTS. A missing script is skipped, not fatal — the gauntlet keeps going.
+run_experiment() {
+    local NAME="$1"
+    local SCRIPT="$2"
+    local EXTRA_ENV="${3:-}"
+
+    echo ""
+    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+    echo " ▶ $NAME"
+    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+
+    if [ ! -f "$SCRIPT" ]; then
+        echo " SKIP: script not found: $SCRIPT"
+        return
+    fi
+
+    local RUN_ID LOGFILE
+    RUN_ID="gauntlet_${NAME}_$(date +%H%M%S)"
+    LOGFILE="$LOG_DIR/${RUN_ID}.log"
+    local T_START
+    T_START=$(date +%s)
+
+    # NOTE(review): ACCUM_STEPS is computed here but never passed to the run
+    # below (the env block sets no GRAD_ACCUM_STEPS), so the old claim that
+    # 1 GPU runs "the full batch in one pass" is not what happens — confirm
+    # intent and either wire this through or delete it.
+    ACCUM_STEPS=$(( 8 / GPUS ))
+    [[ $ACCUM_STEPS -lt 1 ]] && ACCUM_STEPS=1
+
+    env \
+        RUN_ID="$RUN_ID" \
+        VOCAB_SIZE="$VOCAB" \
+        DATA_DIR="$REPO_DIR/data" \
+        ITERATIONS="$STEPS" \
+        VAL_LOSS_EVERY=500 \
+        MAX_WALLCLOCK_SECONDS=580 \
+        SLIDING_WINDOW_ENABLED=1 \
+        TRAIN_BATCH_TOKENS=786432 \
+        $EXTRA_ENV \
+        torchrun --standalone --nproc_per_node="$GPUS" "$SCRIPT" 2>&1 | tee "$LOGFILE"
+
+    local T_END
+    T_END=$(date +%s)
+    local ELAPSED=$(( T_END - T_START ))
+
+    # Extract best/last val_bpb from the log. \s* tolerates both observed log
+    # formats: "val_bpb:1.2964" (baseline) and "val_bpb: 1.3673" (records
+    # scripts) — the old pattern silently skipped the spaced variant.
+    local BEST_BPB LAST_BPB
+    BEST_BPB=$(grep -oP 'val_bpb:\s*\K[0-9.]+' "$LOGFILE" | sort -n | head -1)
+    LAST_BPB=$(grep -oP 'val_bpb:\s*\K[0-9.]+' "$LOGFILE" | tail -1)
+    # A pipeline's exit status is the LAST command's, so the previous
+    # `grep ... | head -1 || echo "N/A"` fallback could never fire
+    # (head exits 0 even on empty input). Test emptiness explicitly.
+    [ -n "$BEST_BPB" ] || BEST_BPB="N/A"
+    [ -n "$LAST_BPB" ] || LAST_BPB="N/A"
+
+    echo ""
+    echo " ✓ DONE — elapsed: ${ELAPSED}s | best_bpb: $BEST_BPB | last_bpb: $LAST_BPB"
+
+    # Append to results file
+    echo "$(date '+%Y-%m-%d %H:%M:%S') | experiment=${NAME} | steps=${STEPS} | vocab=sp${VOCAB} | gpus=${GPUS} | best_bpb=${BEST_BPB} | last_bpb=${LAST_BPB} | elapsed=${ELAPSED}s | log=${LOGFILE}" \
+        >> "$RESULTS"
+}
+
+# ======================================================
+# 0. 
BASELINE — original train_gpt.py (control group)
+# ======================================================
+if [[ $INCR_ONLY -eq 0 && $DEQ_ONLY -eq 0 && $SEEDS_ONLY -eq 0 && $MOD_ONLY -eq 0 && $SKIP_BASELINE -eq 0 ]]; then
+  run_experiment "baseline" \
+    "$REPO_DIR/train_gpt.py"
+fi
+
+# ======================================================
+# 1. INCREMENTAL SUBMISSION
+#    QK-Gain 5.5 + 4-loop recurrence + early parallel residuals + selective TTT
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $SEEDS_ONLY -eq 0 && $MOD_ONLY -eq 0 ]]; then
+  INCR_SCRIPT=$(ls -t "$REPO_DIR"/records/track_10min_16mb/2026-04-23_QK55_*/train_gpt.py 2>/dev/null | head -1)
+  if [ -n "$INCR_SCRIPT" ]; then
+    run_experiment "incr_QK55_4loop" "$INCR_SCRIPT"
+  else
+    echo "WARN: incremental submission not found, skipping"
+  fi
+fi
+
+# ======================================================
+# 2. ABLATION: QK-Gain 5.5 alone (no loop change)
+#    Lets us isolate the gain of QK 5.5 vs 5.25
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $SEEDS_ONLY -eq 0 && $INCR_ONLY -eq 0 && $MOD_ONLY -eq 0 ]]; then
+  run_experiment "ablation_QK55_only" \
+    "$REPO_DIR/train_gpt.py" \
+    "QK_GAIN_INIT=5.5"
+fi
+
+# ======================================================
+# 3. ABLATION: 4-loop only (QK stays at 5.25)
+#    NOTE(review): passes NUM_LOOPS=3 while the name says "4-loop"; the incr
+#    submission's own hyperparameter dump also shows num_loops: 3 under a
+#    "4Loop" name, so the off-by-one naming looks conventional — confirm.
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $SEEDS_ONLY -eq 0 && $INCR_ONLY -eq 0 && $MOD_ONLY -eq 0 ]]; then
+  run_experiment "ablation_4loop_only" \
+    "$REPO_DIR/train_gpt.py" \
+    "NUM_LOOPS=3"
+fi
+
+# ======================================================
+# 4. 
DEQ UNIVERSAL TRANSFORMER
+#    1 physical block → fixed-point convergence (Anderson acceleration)
+# ======================================================
+if [[ $SEEDS_ONLY -eq 0 && $INCR_ONLY -eq 0 && $SKIP_DEQ -eq 0 && $MOD_ONLY -eq 0 ]]; then
+  run_experiment "deq_universal" \
+    "$REPO_DIR/experiments/train_gpt_deq.py" \
+    "DEQ_MAX_ITER_TRAIN=8 DEQ_PHANTOM_STEPS=4 DEQ_MAX_ITER_EVAL=16 DEQ_TOL=1e-3"
+fi
+
+# ======================================================
+# 5. SEED-LORA
+#    Random basis weights (seeded) + rank-8/4 LoRA adapters only stored
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $INCR_ONLY -eq 0 && $MOD_ONLY -eq 0 ]]; then
+  run_experiment "seed_lora_r8" \
+    "$REPO_DIR/experiments/train_gpt_seeds.py" \
+    "LORA_RANK_ATTN=8 LORA_RANK_MLP=4 LEARN_RANDOM_SCALE=1"
+fi
+
+# ======================================================
+# 6. SEED-LORA HIGH RANK
+#    Since we have budget left, try rank-32 (still tiny vs 16MB)
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $INCR_ONLY -eq 0 && $MOD_ONLY -eq 0 ]]; then
+  run_experiment "seed_lora_r32" \
+    "$REPO_DIR/experiments/train_gpt_seeds.py" \
+    "LORA_RANK_ATTN=32 LORA_RANK_MLP=16 LEARN_RANDOM_SCALE=1"
+fi
+
+# ======================================================
+# 7. MIXTURE OF DEPTHS — 50% routing capacity
+#    ~2× faster training → more steps in 10 min → lower BPB
+#    On OpenAI's wish list. Explicitly mentioned in README.
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $SEEDS_ONLY -eq 0 && $INCR_ONLY -eq 0 ]]; then
+  run_experiment "mod_capacity50" \
+    "$REPO_DIR/experiments/train_gpt_mod.py" \
+    "MOD_CAPACITY=0.5 MOD_LAYERS=all MOD_AUX_LOSS_COEF=0.01"
+fi
+
+# ======================================================
+# 8. MIXTURE OF DEPTHS — 25% routing (more aggressive)
+#    Skips 75% of tokens per layer — even faster training
+#    but may hurt quality. Tests the tradeoff. 
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $SEEDS_ONLY -eq 0 && $INCR_ONLY -eq 0 ]]; then
+  run_experiment "mod_capacity25" \
+    "$REPO_DIR/experiments/train_gpt_mod.py" \
+    "MOD_CAPACITY=0.25 MOD_LAYERS=all MOD_AUX_LOSS_COEF=0.005"
+fi
+
+# ======================================================
+# 9. SPECULATIVE MUON — faster Newton-Schulz (2 steps vs 5)
+#    ~40% less optimizer compute → more training steps in 10 min
+#    Free ablation: just change MUON_BACKEND_STEPS
+#    NOTE(review): the INCR_SCRIPT lookup below duplicates section 1 (and
+#    section 10) verbatim — consider hoisting it once above section 1.
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $SEEDS_ONLY -eq 0 && $INCR_ONLY -eq 0 ]]; then
+  INCR_SCRIPT=$(ls -t "$REPO_DIR"/records/track_10min_16mb/2026-04-23_QK55_*/train_gpt.py 2>/dev/null | head -1)
+  if [ -n "$INCR_SCRIPT" ]; then
+    run_experiment "speculative_muon_ns2" \
+      "$INCR_SCRIPT" \
+      "MUON_BACKEND_STEPS=2"
+  fi
+fi
+
+# ======================================================
+# 10. SPECULATIVE MUON — 3 steps (middle ground)
+# ======================================================
+if [[ $DEQ_ONLY -eq 0 && $SEEDS_ONLY -eq 0 && $INCR_ONLY -eq 0 ]]; then
+  INCR_SCRIPT=$(ls -t "$REPO_DIR"/records/track_10min_16mb/2026-04-23_QK55_*/train_gpt.py 2>/dev/null | head -1)
+  if [ -n "$INCR_SCRIPT" ]; then
+    run_experiment "speculative_muon_ns3" \
+      "$INCR_SCRIPT" \
+      "MUON_BACKEND_STEPS=3"
+  fi
+fi
+
+# ======================================================
+# RESULTS SUMMARY
+# ======================================================
+echo ""
+echo "════════════════════════════════════════════════"
+echo " GAUNTLET COMPLETE — $(date)"
+echo "════════════════════════════════════════════════"
+echo ""
+if [ -f "$RESULTS" ]; then
+  echo "Results (sorted by best_bpb):"
+  echo ""
+  # Print header
+  printf " %-30s %-10s %-10s %-8s\n" "EXPERIMENT" "BEST_BPB" "LAST_BPB" "TIME"
+  printf " %-30s %-10s %-10s %-8s\n" "----------" "--------" "--------" "----"
+  # Parse and sort results
+  grep "$(date '+%Y-%m-%d')" "$RESULTS" 
2>/dev/null | \ + awk -F'|' ' + { + for(i=1;i<=NF;i++){ + if($i ~ /experiment=/) { split($i,a,"="); name=a[2] } + if($i ~ /best_bpb=/) { split($i,a,"="); best=a[2] } + if($i ~ /last_bpb=/) { split($i,a,"="); last=a[2] } + if($i ~ /elapsed=/) { split($i,a,"="); elapsed=a[2] } + } + print best, name, last, elapsed + }' | sort -n | \ + while read bpb name last_bpb elapsed; do + printf " %-30s %-10s %-10s %-8s\n" "$name" "$bpb" "$last_bpb" "$elapsed" + done + echo "" + echo "Current SOTA: 1.0810 bpb (bigbag, 2026-04-09)" +fi +echo "" +echo "Full results: $RESULTS" +echo "Logs: $LOG_DIR/" diff --git a/gauntlet_results.txt b/gauntlet_results.txt new file mode 100644 index 0000000000..80e2b806ff --- /dev/null +++ b/gauntlet_results.txt @@ -0,0 +1,2 @@ +2026-04-23 01:42:16 | experiment=baseline | steps=2000 | vocab=sp1024 | gpus=1 | best_bpb=1.2964 | last_bpb=1.29774005 | elapsed=1263s | log=/workspace/parameter-golf/logs/gauntlet/gauntlet_baseline_012113.log +2026-04-23 03:24:47 | experiment=incr_QK55_4loop | steps=2000 | vocab=sp1024 | gpus=1 | best_bpb=1.17158907 | last_bpb=1.17158907 | elapsed=5540s | log=/workspace/parameter-golf/logs/gauntlet/gauntlet_incr_QK55_4loop_015227.log diff --git a/logs_pod1/gauntlet/gauntlet_baseline_011757.log b/logs_pod1/gauntlet/gauntlet_baseline_011757.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/logs_pod1/gauntlet/gauntlet_baseline_012113.log b/logs_pod1/gauntlet/gauntlet_baseline_012113.log new file mode 100644 index 0000000000..e58caa6721 --- /dev/null +++ b/logs_pod1/gauntlet/gauntlet_baseline_012113.log @@ -0,0 +1,64 @@ +logs/gauntlet_baseline_012113.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:5 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True 
mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms +step:1/2000 train_loss:6.9357 train_time:471ms step_avg:471.35ms +step:2/2000 train_loss:16.7414 train_time:977ms step_avg:488.32ms +step:3/2000 train_loss:8.7524 train_time:1494ms step_avg:497.98ms +step:4/2000 train_loss:6.5885 train_time:2021ms step_avg:505.21ms +step:5/2000 train_loss:6.6522 train_time:2572ms step_avg:514.48ms +step:6/2000 train_loss:6.6999 train_time:3075ms step_avg:512.47ms +step:7/2000 train_loss:6.3335 train_time:3603ms step_avg:514.65ms +step:8/2000 train_loss:6.1811 train_time:4162ms step_avg:520.30ms +step:9/2000 train_loss:6.0837 train_time:4645ms step_avg:516.09ms +step:10/2000 train_loss:5.9956 train_time:5172ms step_avg:517.15ms +step:200/2000 train_loss:2.7943 train_time:105274ms step_avg:526.37ms +step:400/2000 train_loss:2.3993 train_time:210577ms step_avg:526.44ms +step:500/2000 val_loss:2.5003 val_bpb:1.4808 train_time:263172ms step_avg:526.34ms +step:600/2000 train_loss:2.4985 train_time:316018ms step_avg:526.70ms +step:800/2000 train_loss:2.3520 train_time:421852ms step_avg:527.31ms +step:1000/2000 train_loss:2.4717 train_time:527119ms step_avg:527.12ms +step:1000/2000 val_loss:2.3214 val_bpb:1.3748 train_time:527120ms step_avg:527.12ms +step:1200/2000 train_loss:2.2456 train_time:632752ms step_avg:527.29ms 
+step:1400/2000 train_loss:2.1673 train_time:739270ms step_avg:528.05ms +step:1500/2000 val_loss:2.2396 val_bpb:1.3264 train_time:791889ms step_avg:527.93ms +step:1600/2000 train_loss:2.1944 train_time:845076ms step_avg:528.17ms +step:1800/2000 train_loss:2.1528 train_time:955409ms step_avg:530.78ms +step:2000/2000 train_loss:2.2258 train_time:1066704ms step_avg:533.35ms +step:2000/2000 val_loss:2.1889 val_bpb:1.2964 train_time:1066705ms step_avg:533.35ms +peak memory allocated: 10238 MiB reserved: 10730 MiB +Serialized model: 67224983 bytes +Code size: 47686 bytes +Total submission size: 67272669 bytes +Serialized model int8+zlib: 14961007 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 15008693 bytes +final_int8_zlib_roundtrip val_loss:2.1912 val_bpb:1.2977 eval_time:11181ms +final_int8_zlib_roundtrip_exact val_loss:2.19117971 val_bpb:1.29774005 diff --git a/logs_pod1/gauntlet/gauntlet_incr_QK55_4loop_015227.log b/logs_pod1/gauntlet/gauntlet_incr_QK55_4loop_015227.log new file mode 100644 index 0000000000..b965e907ad --- /dev/null +++ b/logs_pod1/gauntlet/gauntlet_incr_QK55_4loop_015227.log @@ -0,0 +1,176 @@ +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /workspace/parameter-golf/data + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 8 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 2000 + ln_scale: True + local_rank: 0 + logfile: logs/gauntlet_incr_QK55_4loop_015227.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 0.0 + 
min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 5 + qk_gain_init: 5.5 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: gauntlet_incr_QK55_4loop_015227 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 24576 + ttt_enabled: False + ttt_epochs: 4 + ttt_lr: 0.005 + ttt_momentum: 0.9 + ttt_selective: True + val_batch_tokens: 524288 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin + val_loss_every: 500 + vocab_size: 1024 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 1 +train_shards: 5 +val_tokens: 62021632 +model_params: 32276568 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup: enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4, 5, 3] decoder:[4, 5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/2000 val_loss: 6.9384 val_bpb: 4.1093 +1/2000 train_loss: 6.9345 train_time: 0.0m tok/s: 1200803 +2/2000 train_loss: 12.1271 train_time: 0.0m tok/s: 1189039 +3/2000 train_loss: 9.9963 train_time: 0.0m tok/s: 
1178249 +4/2000 train_loss: 7.8728 train_time: 0.0m tok/s: 1174070 +5/2000 train_loss: 6.5889 train_time: 0.1m tok/s: 1171437 +500/2000 train_loss: 2.3018 train_time: 5.7m tok/s: 1146850 +500/2000 val_loss: 2.3087 val_bpb: 1.3673 +layer_loop: enabled step:700 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4, 5, 3] decoder:[4, 5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/2000 train_loss: 2.1383 train_time: 14.1m tok/s: 927822 +1000/2000 val_loss: 2.1413 val_bpb: 1.2682 +1500/2000 train_loss: 2.0394 train_time: 24.3m tok/s: 808192 +1500/2000 val_loss: 2.0580 val_bpb: 1.2188 +2000/2000 train_loss: 1.9188 train_time: 34.5m tok/s: 758935 +2000/2000 val_loss: 1.9713 val_bpb: 1.1675 +peak memory allocated: 39808 MiB reserved: 39916 MiB +ema: applying EMA weights +pre-quantization post-ema val_loss:1.97817906 val_bpb:1.17158907 eval_time:33143ms +Serialized model: 128099193 bytes +Code size: 65269 bytes +GPTQ: collecting Hessians from calibration data... +GPTQ: collected 67 Hessians in 17.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +[rank0]: Traceback (most recent call last): +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1385, in +[rank0]: main() +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1379, in main +[rank0]: train_and_eval(h, device) +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1319, in train_and_eval +[rank0]: serialize(h, base_model, Path(__file__).read_text(encoding='utf-8')) +[rank0]: File 
"/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 899, in serialize +[rank0]: quant_blob = _compress(quant_raw, h.compressor) +[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 864, in _compress +[rank0]: import brotli +[rank0]: ModuleNotFoundError: No module named 'brotli' +[rank0]:[W423 03:24:44.766043241 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) +E0423 03:24:46.722000 8738 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 0 (pid: 8807) of binary: /usr/local/bin/python +Traceback (most recent call last): + File "/usr/local/bin/torchrun", line 7, in + sys.exit(main()) + ^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 936, in main + run(args) + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 927, in run + elastic_launch( + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 156, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 293, in launch_agent + raise ChildFailedError( +torch.distributed.elastic.multiprocessing.errors.ChildFailedError: +============================================================ 
+/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py FAILED +------------------------------------------------------------ +Failures: + +------------------------------------------------------------ +Root Cause (first observed failure): +[0]: + time : 2026-04-23_03:24:46 + host : abfc9671a5d1 + rank : 0 (local_rank: 0) + exitcode : 1 (pid: 8807) + error_file: + traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html +============================================================ diff --git a/logs_pod1/gauntlet_baseline_012113.txt b/logs_pod1/gauntlet_baseline_012113.txt new file mode 100644 index 0000000000..cdb1e70021 --- /dev/null +++ b/logs_pod1/gauntlet_baseline_012113.txt @@ -0,0 +1,1215 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: 
Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, 
(x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + 
log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, 
args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 
0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # 
----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in 
optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Apr 23 01:21:18 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 35C P0 92W / 700W | 1185MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 3696 C /usr/local/bin/python 1176MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:5 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 
matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms +step:1/2000 train_loss:6.9357 train_time:471ms step_avg:471.35ms +step:2/2000 train_loss:16.7414 train_time:977ms step_avg:488.32ms +step:3/2000 train_loss:8.7524 train_time:1494ms step_avg:497.98ms +step:4/2000 train_loss:6.5885 train_time:2021ms step_avg:505.21ms +step:5/2000 train_loss:6.6522 train_time:2572ms step_avg:514.48ms +step:6/2000 train_loss:6.6999 train_time:3075ms step_avg:512.47ms +step:7/2000 train_loss:6.3335 train_time:3603ms step_avg:514.65ms +step:8/2000 train_loss:6.1811 train_time:4162ms step_avg:520.30ms +step:9/2000 train_loss:6.0837 train_time:4645ms step_avg:516.09ms +step:10/2000 train_loss:5.9956 train_time:5172ms step_avg:517.15ms +step:200/2000 train_loss:2.7943 train_time:105274ms step_avg:526.37ms +step:400/2000 train_loss:2.3993 train_time:210577ms step_avg:526.44ms +step:500/2000 val_loss:2.5003 val_bpb:1.4808 train_time:263172ms step_avg:526.34ms +step:600/2000 train_loss:2.4985 train_time:316018ms step_avg:526.70ms +step:800/2000 train_loss:2.3520 train_time:421852ms step_avg:527.31ms +step:1000/2000 train_loss:2.4717 train_time:527119ms step_avg:527.12ms +step:1000/2000 val_loss:2.3214 val_bpb:1.3748 train_time:527120ms step_avg:527.12ms +step:1200/2000 train_loss:2.2456 train_time:632752ms step_avg:527.29ms +step:1400/2000 train_loss:2.1673 train_time:739270ms step_avg:528.05ms +step:1500/2000 val_loss:2.2396 val_bpb:1.3264 train_time:791889ms 
step_avg:527.93ms +step:1600/2000 train_loss:2.1944 train_time:845076ms step_avg:528.17ms +step:1800/2000 train_loss:2.1528 train_time:955409ms step_avg:530.78ms +step:2000/2000 train_loss:2.2258 train_time:1066704ms step_avg:533.35ms +step:2000/2000 val_loss:2.1889 val_bpb:1.2964 train_time:1066705ms step_avg:533.35ms +peak memory allocated: 10238 MiB reserved: 10730 MiB +Serialized model: 67224983 bytes +Code size: 47686 bytes +Total submission size: 67272669 bytes +Serialized model int8+zlib: 14961007 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 15008693 bytes +final_int8_zlib_roundtrip val_loss:2.1912 val_bpb:1.2977 eval_time:11181ms +final_int8_zlib_roundtrip_exact val_loss:2.19117971 val_bpb:1.29774005 diff --git a/logs_pod1/gauntlet_incr.log b/logs_pod1/gauntlet_incr.log new file mode 100644 index 0000000000..c82132c126 --- /dev/null +++ b/logs_pod1/gauntlet_incr.log @@ -0,0 +1,206 @@ +======================================== + GAUNTLET TEST — Thu Apr 23 01:52:27 UTC 2026 + steps=2000 vocab=sp1024 gpus=1 +======================================== + +Data: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024 ✓ +Tokenizer: /workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model ✓ + + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + ▶ incr_QK55_4loop +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /workspace/parameter-golf/data + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 8 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 2000 + ln_scale: True + local_rank: 0 + 
logfile: logs/gauntlet_incr_QK55_4loop_015227.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 0.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 5 + qk_gain_init: 5.5 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: gauntlet_incr_QK55_4loop_015227 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 24576 + ttt_enabled: False + ttt_epochs: 4 + ttt_lr: 0.005 + ttt_momentum: 0.9 + ttt_selective: True + val_batch_tokens: 524288 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin + val_loss_every: 500 + vocab_size: 1024 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 1 +train_shards: 5 +val_tokens: 62021632 +model_params: 32276568 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup: enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4, 5, 3] decoder:[4, 5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/2000 
val_loss: 6.9384 val_bpb: 4.1093 +1/2000 train_loss: 6.9345 train_time: 0.0m tok/s: 1200803 +2/2000 train_loss: 12.1271 train_time: 0.0m tok/s: 1189039 +3/2000 train_loss: 9.9963 train_time: 0.0m tok/s: 1178249 +4/2000 train_loss: 7.8728 train_time: 0.0m tok/s: 1174070 +5/2000 train_loss: 6.5889 train_time: 0.1m tok/s: 1171437 +500/2000 train_loss: 2.3018 train_time: 5.7m tok/s: 1146850 +500/2000 val_loss: 2.3087 val_bpb: 1.3673 +layer_loop: enabled step:700 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4, 5, 3] decoder:[4, 5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/2000 train_loss: 2.1383 train_time: 14.1m tok/s: 927822 +1000/2000 val_loss: 2.1413 val_bpb: 1.2682 +1500/2000 train_loss: 2.0394 train_time: 24.3m tok/s: 808192 +1500/2000 val_loss: 2.0580 val_bpb: 1.2188 +2000/2000 train_loss: 1.9188 train_time: 34.5m tok/s: 758935 +2000/2000 val_loss: 1.9713 val_bpb: 1.1675 +peak memory allocated: 39808 MiB reserved: 39916 MiB +ema: applying EMA weights +pre-quantization post-ema val_loss:1.97817906 val_bpb:1.17158907 eval_time:33143ms +Serialized model: 128099193 bytes +Code size: 65269 bytes +GPTQ: collecting Hessians from calibration data... 
+GPTQ: collected 67 Hessians in 17.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +[rank0]: Traceback (most recent call last): +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1385, in +[rank0]: main() +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1379, in main +[rank0]: train_and_eval(h, device) +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1319, in train_and_eval +[rank0]: serialize(h, base_model, Path(__file__).read_text(encoding='utf-8')) +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 899, in serialize +[rank0]: quant_blob = _compress(quant_raw, h.compressor) +[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 864, in _compress +[rank0]: import brotli +[rank0]: ModuleNotFoundError: No module named 'brotli' +[rank0]:[W423 03:24:44.766043241 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. 
For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) +E0423 03:24:46.722000 8738 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 0 (pid: 8807) of binary: /usr/local/bin/python +Traceback (most recent call last): + File "/usr/local/bin/torchrun", line 7, in + sys.exit(main()) + ^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 936, in main + run(args) + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 927, in run + elastic_launch( + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 156, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 293, in launch_agent + raise ChildFailedError( +torch.distributed.elastic.multiprocessing.errors.ChildFailedError: +============================================================ +/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py FAILED +------------------------------------------------------------ +Failures: + +------------------------------------------------------------ +Root Cause (first observed failure): +[0]: + time : 2026-04-23_03:24:46 + host : abfc9671a5d1 + rank : 0 (local_rank: 0) + exitcode : 1 (pid: 8807) + error_file: + traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html +============================================================ + + ✓ DONE — elapsed: 5540s | best_bpb: 1.17158907 | last_bpb: 1.17158907 + +════════════════════════════════════════════════ + GAUNTLET COMPLETE — Thu Apr 
23 03:24:47 UTC 2026 +════════════════════════════════════════════════ + +Results (sorted by best_bpb): + + EXPERIMENT BEST_BPB LAST_BPB TIME + ---------- -------- -------- ---- + incr_QK55_4loop 1.17158907 1.17158907 5540s + baseline 1.2964 1.29774005 1263s + +Current SOTA: 1.0810 bpb (bigbag, 2026-04-09) + +Full results: /workspace/parameter-golf/gauntlet_results.txt +Logs: /workspace/parameter-golf/logs/gauntlet/ diff --git a/logs_pod1/gauntlet_incr_QK55_4loop_015227.txt b/logs_pod1/gauntlet_incr_QK55_4loop_015227.txt new file mode 100644 index 0000000000..19ed9447b8 --- /dev/null +++ b/logs_pod1/gauntlet_incr_QK55_4loop_015227.txt @@ -0,0 +1,154 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /workspace/parameter-golf/data + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 8 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 2000 + ln_scale: True + local_rank: 0 + logfile: logs/gauntlet_incr_QK55_4loop_015227.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 0.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 5 + qk_gain_init: 5.5 + quantized_model_path: final_model.int6.ptz + rank: 0 
+ rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: gauntlet_incr_QK55_4loop_015227 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 24576 + ttt_enabled: False + ttt_epochs: 4 + ttt_lr: 0.005 + ttt_momentum: 0.9 + ttt_selective: True + val_batch_tokens: 524288 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin + val_loss_every: 500 + vocab_size: 1024 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 1 +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Apr 23 01:52:32 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 40C P0 94W / 700W | 1185MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 8807 C /usr/local/bin/python 1176MiB | ++-----------------------------------------------------------------------------------------+ + +train_shards: 5 +val_tokens: 62021632 +model_params: 32276568 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup: enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4, 5, 3] decoder:[4, 5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/2000 val_loss: 6.9384 val_bpb: 4.1093 +1/2000 train_loss: 6.9345 train_time: 0.0m tok/s: 1200803 +2/2000 train_loss: 12.1271 train_time: 0.0m tok/s: 1189039 +3/2000 train_loss: 9.9963 train_time: 0.0m tok/s: 1178249 +4/2000 train_loss: 7.8728 train_time: 0.0m tok/s: 1174070 +5/2000 train_loss: 6.5889 train_time: 0.1m tok/s: 1171437 +500/2000 train_loss: 2.3018 train_time: 5.7m tok/s: 1146850 +500/2000 val_loss: 2.3087 val_bpb: 1.3673 +layer_loop: enabled step:700 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4, 5, 3] decoder:[4, 5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/2000 train_loss: 2.1383 train_time: 14.1m tok/s: 927822 +1000/2000 val_loss: 2.1413 val_bpb: 1.2682 +1500/2000 train_loss: 2.0394 train_time: 24.3m tok/s: 808192 +1500/2000 
val_loss: 2.0580 val_bpb: 1.2188 +2000/2000 train_loss: 1.9188 train_time: 34.5m tok/s: 758935 +2000/2000 val_loss: 1.9713 val_bpb: 1.1675 +peak memory allocated: 39808 MiB reserved: 39916 MiB +ema: applying EMA weights +pre-quantization post-ema val_loss:1.97817906 val_bpb:1.17158907 eval_time:33143ms +Serialized model: 128099193 bytes +Code size: 65269 bytes +GPTQ: collecting Hessians from calibration data... +GPTQ: collected 67 Hessians in 17.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights diff --git a/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/README.md b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/README.md new file mode 100644 index 0000000000..2674c6a044 --- /dev/null +++ b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/README.md @@ -0,0 +1,107 @@ +# val_bpb=1.1716 — 4-Loop Depth Recurrence + Early Parallel Residuals + Selective TTT + +**Author:** Ismail Haddou (@ismailntl) +**Confirmed val_bpb:** 1.17158907 (1×H100, seed 1337) +**3-seed 8×H100 run:** pending compute credits + +## Results + +| step | val_bpb | +|---|---| +| 500 | 1.3673 | +| 1000 | 1.2682 | +| 1500 | 1.2188 | +| 2000 | 1.1675 | +| post-EMA | **1.17158907** | + +Baseline (provided `train_gpt.py`) achieved 1.2977 at the same step count. + +## What this submission does + +Four changes applied together: + +1. **QK-Gain 5.5** — attention key/query gain parameter tuned to 5.5 +2. **NUM_LOOPS=3** — 4 recurrence passes through layers 3-5, giving 20 virtual layer executions from 11 physical layers +3. **Early Parallel Residuals** — GPT-J style parallel attention+MLP from layer 5 onward +4.
**Selective TTT** — test-time training restricted to recurrent layers only, chunk size 24576, 4 epochs per chunk + +## Architecture + +11L × 512d × 8H / 4KV, MLP 4×, LeakyReLU(0.5)², Partial RoPE (16/64 dims), tied embeddings, logit softcap=30.0. SP1024 BPE tokenizer. Depth recurrence on layers 3-5 (activates at step 700, frac=0.35). Skip gates (sigmoid-gated U-Net connections). GPTQ SDClip int6/int8 + Brotli-11 compression. + +## How to run + +### Prerequisites + +```bash +# Clone and enter repo +git clone https://github.com/ismailntl/parameter-golf.git +cd parameter-golf + +# Download sp1024 dataset (if not already present) +python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 10 +``` + +### Single run (1×H100, ablation) + +```bash +RUN_ID=qk55_4loop \ +DATA_DIR=./data \ +VOCAB_SIZE=1024 \ +ITERATIONS=2000 \ +MAX_WALLCLOCK_SECONDS=0 \ +SLIDING_WINDOW_ENABLED=1 \ +TRAIN_BATCH_TOKENS=786432 \ +torchrun --standalone --nproc_per_node=1 \ + records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py +``` + +### Official 8×H100 submission run (10-min cap enforced) + +```bash +RUN_ID=qk55_4loop_8gpu \ +DATA_DIR=./data \ +VOCAB_SIZE=1024 \ +ITERATIONS=2000 \ +MAX_WALLCLOCK_SECONDS=580 \ +SLIDING_WINDOW_ENABLED=1 \ +TRAIN_BATCH_TOKENS=786432 \ +torchrun --standalone --nproc_per_node=8 \ + records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py +``` + +### Gauntlet runner (all experiments + ablations) + +```bash +# Full gauntlet on 8×H100 +bash gauntlet.sh --vocab 1024 --gpus 8 --incr-only + +# Ablation on 1×H100 +bash gauntlet.sh --vocab 1024 --gpus 1 --incr-only +``` + +`MAX_WALLCLOCK_SECONDS=580` is set automatically by the gauntlet, leaving 20s for GPTQ serialization within the 10-minute window. 
+ +### Additional experiments (in `experiments/`) + +```bash +# DEQ Universal Transformer (1 physical block → fixed-point) +torchrun --standalone --nproc_per_node=8 experiments/train_gpt_deq.py + +# Seed-LoRA (random linear map bases + stored LoRA adapters only) +LORA_RANK_ATTN=8 LORA_RANK_MLP=4 \ +torchrun --standalone --nproc_per_node=8 experiments/train_gpt_seeds.py + +# Mixture of Depths (50% token routing → ~2× more training steps) +MOD_CAPACITY=0.5 \ +torchrun --standalone --nproc_per_node=8 experiments/train_gpt_mod.py +``` + +## Files + +| File | Description | +|---|---| +| `train_gpt.py` | This submission's training script | +| `train_log_seed1337.log` | Full training log (seed 1337, 1×H100) | +| `train_log_baseline.log` | Baseline run log for comparison | +| `submission.json` | Metadata | diff --git a/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/submission.json b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/submission.json new file mode 100644 index 0000000000..0d75813035 --- /dev/null +++ b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/submission.json @@ -0,0 +1,31 @@ +{ + "author": "Ismail Haddou", + "github_id": "ismailntl", + "name": "QK-Gain 5.5 + 4-Loop Recurrence + Early Parallel Residuals + Selective TTT", + "date": "2026-04-23", + "track": "10min_16mb", + "val_bpb": 1.17158907, + "val_bpb_std": null, + "seeds": [], + "seed_results": {}, + "hardware": "1xH100 80GB HBM3 (ablation run, 2000 steps). 
Full submission targets 8xH100 within 600s via MAX_WALLCLOCK_SECONDS=580.", + "pytorch_version": "2.x+cu128", + "technique_summary": "QK-Gain 5.5 + NUM_LOOPS=3 (4 recurrence passes, 20 virtual layers) + Early Parallel Residuals (layer 5+) + Selective TTT on loop layers + smaller TTT chunk (24576) + more TTT epochs (4) on the SP1024+GPTQ+Brotli stack", + "status": "pending-compute", + "notes": "Pre-quantization bpb=1.17158907 confirmed on 1xH100 ablation (2000 steps, 34.5m). On 8xH100 same wallclock ~4-5m training, total <10m. MAX_WALLCLOCK_SECONDS=580 added for 10-min compliance. Brotli install fixed for Python 3.12. Awaiting compute credits for 3-seed 8xH100 run.", + "compliance": { + "train_under_600s": "enforced via MAX_WALLCLOCK_SECONDS=580", + "artifact_under_16mb": "pending final quantized run", + "eval_under_600s": null, + "no_slot": true, + "three_seeds": false + }, + "attribution": { + "sp1024_gptq_sdclip": "@clarkkev (PR #1394)", + "depth_recurrence": "@dexhunter (PR #1331, #1437)", + "parallel_residuals": "@Robby955 (PR #1412)", + "qk_gain": "@dexhunter (PR #1413)", + "legal_score_first_ttt": "@abaybektursun (PR #549), @dexhunter (PR #1413)", + "full_sota_stack": "@bigbag (PR #1493)" + } +} diff --git a/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py new file mode 100644 index 0000000000..240abe9b8a --- /dev/null +++ b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py @@ -0,0 +1,1427 @@ +""" +Submission: QK-Gain 5.5 + 4-Loop Recurrence + Early Parallel Residuals + Selective TTT + +Based on: SP8192 + 3-Layer Recurrence + Parallel Residuals + Legal TTT (bigbag, 1.0810) + +Changes over SOTA (bigbag, 1.0810 bpb): + 1. qk_gain_init: 5.25 → 5.5 (monotonic trend continues; no sign of saturation yet) + 2. num_loops: 2 → 3 (4 passes through layers 3-5 → 20 virtual layers vs 17) + 3.
parallel_residual_start: 7 → 5 (GPT-J style residuals start earlier) + 4. Selective TTT: only adapt recurrent core (layers loop_start..loop_end) at test time + faster TTT + less overfitting on non-recurrent layers + 5. ttt_chunk_tokens: 32768 → 24576 (finer-grained adaptation steps per chunk) + 6. ttt_epochs: 3 → 4 (more adaptation within tighter chunks) + +Target: 1.075 bpb (−0.006 from SOTA, clears 0.005-nat threshold) +""" + +import collections, copy, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + except ImportError: + flash_attn_3_func = None + + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +class Hyperparameters: + # Data + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get('RUN_ID', str(uuid.uuid4())) + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.72)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 786432)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600)) + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 524288)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 
4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 8192)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 1e4)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + # *** CHANGE 1: QK-Gain 5.5 (was 5.25) *** + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 5.5)) + # *** CHANGE 2: 4-loop recurrence (was 2 = 3 passes, now 3 = 4 passes) *** + num_loops = int(os.environ.get('NUM_LOOPS', 3)) + loop_start = int(os.environ.get('LOOP_START', 3)) + loop_end = int(os.environ.get('LOOP_END', 5)) + enable_looping_at = float(os.environ.get('ENABLE_LOOPING_AT', 0.35)) + # *** CHANGE 3: Parallel residuals from layer 5 (was 7) *** + parallel_residual_start = int(os.environ.get('PARALLEL_RESIDUAL_START', 5)) + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.022)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = 
int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1'))) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.095)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.9965)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + # TTT (test-time training) + ttt_enabled = bool(int(os.environ.get('TTT_ENABLED', '0'))) + ttt_lr = float(os.environ.get('TTT_LR', 0.005)) + # *** CHANGE 6: more epochs per chunk *** + ttt_epochs = int(os.environ.get('TTT_EPOCHS', 4)) + ttt_momentum = float(os.environ.get('TTT_MOMENTUM', 0.9)) + # *** CHANGE 5: tighter chunks for finer adaptation *** + ttt_chunk_tokens = int(os.environ.get('TTT_CHUNK_TOKENS', 24576)) + # *** CHANGE 4: selective TTT — only adapt recurrent layers *** + ttt_selective = bool(int(os.environ.get('TTT_SELECTIVE', '1'))) + # Quantization + compressor = os.environ.get('COMPRESSOR', 'brotli') + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0)) + matrix_bits = int(os.environ.get('MATRIX_BITS', 6)) + embed_bits = int(os.environ.get('EMBED_BITS', 8)) + matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85)) + embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0)) + # Distributed + distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ + rank = int(os.environ.get('RANK', '0')) + 
world_size = int(os.environ.get('WORLD_SIZE', '1')) + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # Derived + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + logfile = f'logs/{run_id}.txt' + model_path = 'final_model.pt' + quantized_model_path = 'final_model.int6.ptz' + + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- +_logger_hparams = None + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, 'a', encoding='utf-8') as f: + print(msg, file=f) + + +# --------------------------------------------------------------------------- +# Data loading +# --------------------------------------------------------------------------- +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError(f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}") + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = \ + build_sentencepiece_luts(self.sp, h.vocab_size, device) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert sp.piece_to_id('▁') != sp.unk_id(), \ + "Tokenizer must have '▁' 
(space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith('▁'): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode('utf-8')) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split too short for seq_len={seq_len}") + return tokens[:usable + 1] + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(' 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = 
torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor(np.array(mm[start_ind:start_ind + self.seq_len + 1], dtype=np.int64)) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims) + self.register_buffer('inv_freq', inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if (self._cos_cached is None or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = 
seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len): + super().__init__() + if dim % num_heads != 0: + raise ValueError('model_dim must be divisible by num_heads') + if num_heads % num_kv_heads != 0: + raise ValueError('num_heads must be divisible by num_kv_heads') + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError('head_dim must be even for RoPE') + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, 
train_seq_len=train_seq_len) + + def forward(self, x): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if flash_attn_3_func is not None: + # FA2 requires fp16 or bf16; cast if needed (FA3 accepts fp32 natively) + orig_dtype = q.dtype + if orig_dtype not in (torch.float16, torch.bfloat16): + y = flash_attn_3_func(q.bfloat16(), k.bfloat16(), v.bfloat16(), causal=True).to(orig_dtype) + else: + y = flash_attn_3_func(q, k, v, causal=True) + else: + # fallback to torch SDPA + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True).transpose(1, 2) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x): + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + train_seq_len, layer_idx=0, ln_scale=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = 
nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.parallel = False + + def forward(self, x, x0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor) + if self.parallel: + mlp_out = self.mlp(self.mlp_norm(x_in) * self.ln_scale_factor) + x_out = (x_in + + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + + self.mlp_scale.to(dtype=x_in.dtype)[None, None, :] * mlp_out) + else: + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * \ + self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = 
Rotary(head_dim, base=h.rope_base, + train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.parallel_residual_start >= 0: + for i in range(h.parallel_residual_start, h.num_layers): + self.blocks[i].parallel = True + self.looping_active = False + # Build layer execution order for depth recurrence + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min(len(self.encoder_indices), len(self.decoder_indices)) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = (nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + if h.skip_gates_enabled else None) + self.loop_start = h.loop_start + self.loop_end = h.loop_end + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, '_zero_init', False): + nn.init.zeros_(module.weight) + elif (module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64): + nn.init.orthogonal_(module.weight, gain=1.0) + + def forward_logits(self, input_ids): + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj 
is not None: + x = self.embed_proj(x) + x0 = x + skips = [] + enc_iter = self.encoder_indices if self.looping_active else range(self.num_encoder_layers) + dec_iter = self.decoder_indices if self.looping_active else range( + self.num_encoder_layers, self.num_encoder_layers + self.num_decoder_layers) + for i in enc_iter: + x = self.blocks[i](x, x0) + skips.append(x) + for skip_idx, i in enumerate(dec_iter): + if skip_idx < self.num_skip_weights and skips: + scaled_skip = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids, target_ids): + logits = self.forward_logits(input_ids) + return F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), reduction='mean') + + +# --------------------------------------------------------------------------- +# Optimizer (MuonEq-R + AdamW) +# --------------------------------------------------------------------------- +def classify_param(name): + if 'tok_emb' in name or 'lm_head' in name: + return 'embed' + if '.mlp.' in name: + return 'mlp' + if '.attn.' in name or ('.proj.' in name and '.mlp.' 
not in name): + return 'attn' + return 'other' + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + a, b, c = 3.4445, -4.7750, 2.0315 + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, + weight_decay=0.0, row_normalize=False): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay, + row_normalize=row_normalize)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group['params'] + if not params: + continue + lr = group['lr'] + momentum = group['momentum'] + backend_steps = group['backend_steps'] + nesterov = group['nesterov'] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if group.get('row_normalize', False): + row_norms = g.float().norm(dim=-1, keepdim=True).clamp_min(1e-7) + g = g / row_norms.to(g.dtype) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr:curr + p.numel()] = g.reshape(-1) + curr += 
p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get('weight_decay', 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + p for p in os.environ.get( + 'CONTROL_TENSOR_NAME_PATTERNS', + 'attn_scale,mlp_scale,resid_mix,q_gain,skip_weight,skip_weights,skip_gates' + ).split(',') if p +) + + +class Optimizers: + def __init__(self, h, base_model): + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [p for name, p in block_named_params + if p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = [p for name, p in block_named_params + if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [{'params': [base_model.tok_emb.weight], 'lr': token_lr, 'base_lr': token_lr}] + self.optimizer_tok = torch.optim.AdamW(tok_params, betas=(h.beta1, h.beta2), + eps=h.adam_eps, weight_decay=h.embed_wd, fused=True) + self.optimizer_muon = Muon(matrix_params, lr=h.matrix_lr, momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize) + for group in self.optimizer_muon.param_groups: + group['base_lr'] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{'params': scalar_params, 'lr': h.scalar_lr, 'base_lr': h.scalar_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.adam_wd, fused=True) + self.optimizers = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if 
base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{'params': [base_model.lm_head.weight], 'lr': h.head_lr, 'base_lr': h.head_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, fused=True) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and param.dtype != torch.float32: + param.data = param.data.float() + + +# --------------------------------------------------------------------------- +# GPTQ Quantization +# --------------------------------------------------------------------------- +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + + def make_hook(name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(x.T, x) + return hook_fn + + for name, module in model.named_modules(): + if isinstance(module, CastedLinear) and module.weight.numel() > 65536: + cat = classify_param(name + '.weight') + if cat in ('mlp', 'attn'): + hooks.append(module.register_forward_hook(make_hook(name + '.weight'))) + + if model.tie_embeddings: + hook_module = model.head_proj if model.head_proj is not None else model.final_norm + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in 
hessians: + hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(x.T, x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook('tok_emb.weight'))) + + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s + + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + for name, tensor in 
state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = 'passthrough (float16)' + continue + cs = h.embed_clip_sigmas if 'tok_emb' in name else h.matrix_clip_sigmas + bits = h.embed_bits if 'tok_emb' in name else h.matrix_bits + q, s = gptq_quantize_weight(t, hessians[name], clip_sigmas=cs, clip_range=2 ** (bits - 1) - 1) + result[name + '.q'] = q + result[name + '.scale'] = s + meta[name] = f'gptq (int{bits})' + categories = collections.defaultdict(set) + for name, cat in meta.items(): + short = re.sub(r'\.\d+$', '', re.sub(r'blocks\.\d+', 'blocks', name)) + categories[cat].add(short) + log('Quantized weights:') + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + + +def dequantize_mixed(result, meta, template_sd): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if 'passthrough' in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + '.q'], result[name + '.scale'] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# --------------------------------------------------------------------------- +# Compression +# --------------------------------------------------------------------------- +_BSHF_MAGIC = b'BSHF' + + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += 
len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == 'lzma': + return lzma.compress(data, preset=6) + elif compressor == 'brotli': + try: + import brotli + except ImportError: + import subprocess, sys + for flags in (['-m', 'pip', 'install', 'brotli', '-q'], + ['-m', 'pip', 'install', 'brotli', '--break-system-packages', '-q'], + ['-m', 'pip', 'install', 'brotli', '--user', '-q']): + try: + subprocess.check_call([sys.executable] + flags) + break + except Exception: + continue + import brotli + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data, compressor): + if compressor == 'lzma': + raw = lzma.decompress(data) + elif compressor == 'brotli': + try: + import brotli + except ImportError: + import subprocess, sys + for flags in (['-m', 'pip', 'install', 'brotli', '-q'], + ['-m', 'pip', 'install', 'brotli', '--break-system-packages', '-q'], + ['-m', 'pip', 'install', 'brotli', '--user', '-q']): + try: + subprocess.check_call([sys.executable] + flags) + break + except Exception: + continue + import brotli + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + return _byte_unshuffle(raw) + + +def serialize(h, base_model, code): + code_bytes = len(code.encode('utf-8')) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) 
+ log(f"Serialized model: {model_bytes} bytes") + log(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + device = torch.device('cuda', h.local_rank) + log('GPTQ: collecting Hessians from calibration data...') + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + hessians = collect_hessians(base_model, calib_loader, h, device, + n_calibration_batches=h.gptq_calibration_batches) + log(f"GPTQ: collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + quant_buf = io.BytesIO() + torch.save({'w': quant_result, 'm': quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, 'wb') as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + with open(h.quantized_model_path, 'rb') as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location='cpu') + deq_state = dequantize_mixed(quant_state['w'], quant_state['m'], sd_cpu) + eval_model.load_state_dict(deq_state, strict=True) + return eval_model + + +# --------------------------------------------------------------------------- +# Evaluation +# --------------------------------------------------------------------------- +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * 
(token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] + & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding(h, device, 
val_data, base_model, batch_seqs=32): + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = total_windows * h.rank // h.world_size + my_e = total_windows * (h.rank + 1) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = logits_fn(x_batch) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction='none').reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if 
dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def eval_val_ttt(h, device, val_data, base_model, batch_seqs=32): + """ + Score-first TTT with selective adaptation. + CHANGE 4: when ttt_selective=True, only adapt parameters inside the recurrent block + (layers loop_start..loop_end). This speeds up TTT and reduces overfitting + on the non-recurrent layers. + """ + rank = h.rank + world_size = h.world_size + seq_len = h.eval_seq_len + stride = h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + ttt_chunk = h.ttt_chunk_tokens + context_size = seq_len - stride + window_starts = [ws for ws in range(0, total_tokens, stride) if ws + context_size < total_tokens] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows = [[] for _ in range(num_chunks)] + for ws in window_starts: + wlen = min(ws + seq_len, total_tokens) - ws + s = 0 if ws == 0 else context_size + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + # *** CHANGE 4: Selective TTT params *** + if h.ttt_selective: + recurrent_layer_indices = set(range(h.loop_start, h.loop_end + 1)) + ttt_params = [] + for name, p in base_model.named_parameters(): + # Match "blocks.{i}." 
where i is in the recurrent range + m = re.match(r'blocks\.(\d+)\.', name) + if m and int(m.group(1)) in recurrent_layer_indices: + ttt_params.append(p) + # Also adapt skip_weights/gates (U-Net connections to/from recurrent region) + if base_model.skip_weights is not None: + ttt_params.append(base_model.skip_weights) + if base_model.skip_gates is not None: + ttt_params.append(base_model.skip_gates) + log(f"ttt: selective mode — adapting {len(ttt_params)} params tensors in layers {h.loop_start}-{h.loop_end}") + else: + ttt_params = list(base_model.parameters()) + log(f"ttt: full model — adapting {len(ttt_params)} param tensors") + + log(f"ttt: start chunks={num_chunks} ttt_lr={h.ttt_lr} ttt_epochs={h.ttt_epochs} " + f"chunk_tokens={ttt_chunk}") + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD(ttt_params, lr=h.ttt_lr, momentum=h.ttt_momentum) + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = len(windows) * rank // world_size + my_e = len(windows) * (rank + 1) // world_size + my_windows = windows[my_s:my_e] + # Phase 1: score (no grad) + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk_tok = val_data.val_tokens[ws:we + 
1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction='none').reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] + & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + # Phase 2: adapt (skip last chunk) + is_last_chunk = ci == num_chunks - 1 + if not is_last_chunk and h.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = h.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = chunk_seqs * rank // world_size + my_seq_e = chunk_seqs * (rank + 1) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(h.ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_data.val_tokens.numel(): + continue + local = val_data.val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, 
op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") + return val_loss, val_bpb + + +# --------------------------------------------------------------------------- +# Training loop +# --------------------------------------------------------------------------- +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if h.distributed: + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) + else: + model = compiled_model + log(f"model_params: {sum(p.numel() for p in base_model.parameters())}") + optimizers = Optimizers(h, base_model) + train_loader = ShuffledSequenceLoader(h, device) + max_wallclock_ms = 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log(f"gptq: reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - 
frac) / h.warmdown_frac, h.min_lr)
        return 1.0

    def step_fn(step, lr_scale):
        # One optimizer step: gradient accumulation over grad_accum_steps
        # micro-batches, Muon momentum warmup, per-group LR scaling, and
        # optional gradient clipping. Returns the mean micro-batch loss.
        optimizers.zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(h.grad_accum_steps):
            if h.distributed:
                # Only synchronize gradients on the final micro-step.
                model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1
            x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps)
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss / h.grad_accum_steps).backward()
        train_loss /= h.grad_accum_steps
        # Linearly ramp Muon momentum from warmup_start to its final value
        # over muon_momentum_warmup_steps.
        frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum
        for group in optimizers.optimizer_muon.param_groups:
            group['momentum'] = muon_momentum
        for opt in optimizers:
            for group in opt.param_groups:
                group['lr'] = group['base_lr'] * lr_scale
        if h.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm)
        optimizers.step()
        return train_loss

    # Warmup phase: snapshot the initial model/optimizer state, run a few
    # steps in both the plain and (if enabled) layer-looping configurations
    # so all code paths compile, then the snapshot is restored afterwards.
    if h.warmup_steps > 0:
        initial_model_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(h.warmup_steps):
            step_fn(warmup_step, 1.0)
            if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps:
                log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}")
        if h.num_loops > 0:
            # Also warm up with layer looping enabled.
            base_model.looping_active = True
            log(f"loop_warmup: enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}")
            for warmup_step in range(h.warmup_steps):
                step_fn(warmup_step, 1.0)
                if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps:
                    log(f"loop_warmup_step: {warmup_step + 1}/{h.warmup_steps}")
            base_model.looping_active = False
        # Discard all warmup progress: restore the pre-warmup weights,
        # optimizer states, and a fresh data loader.
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        optimizers.zero_grad_all()
        if h.distributed:
            model.require_backward_grad_sync = True
        train_loader = ShuffledSequenceLoader(h, device)

    # EMA + main training loop
    ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()}
    ema_decay = h.ema_decay
    training_time_ms = 0.0
    stop_after_step = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    step = 0

    while True:
        last_step = (step == h.iterations or (stop_after_step is not None and step >= stop_after_step))
        should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0)
        if should_validate:
            # Validation time is excluded from the training-time budget:
            # the clock is paused around eval_val and restarted after it.
            torch.cuda.synchronize()
            training_time_ms += 1e3 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(h, device, val_data, model)
            log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}")
            torch.cuda.synchronize()
            t0 = time.perf_counter()
        if last_step:
            if stop_after_step is not None and step < h.iterations:
                log(f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}")
            break
        elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0)
        frac = training_frac(step, elapsed_ms)
        scale = lr_mul(frac)
        # Enable layer looping once enough of the run has elapsed.
        if h.num_loops > 0 and not base_model.looping_active and frac >= h.enable_looping_at:
            base_model.looping_active = True
            log(f"layer_loop: enabled step:{step} frac:{frac:.3f} "
                f"encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}")
        train_loss = step_fn(step, scale)
        # EMA update of every state-dict tensor, accumulated in fp32.
        with torch.no_grad():
            for name, t in base_model.state_dict().items():
                ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
        step += 1
        approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0)
        should_log_train = h.train_log_every > 0 and (
            step <= 5
            or step % h.train_log_every == 0 or stop_after_step is not None)
        if should_log_train:
            tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3)
            log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} "
                f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}")
        # Wall-clock cap check; the MAX all-reduce keeps all ranks agreeing
        # on when to stop, so nobody hangs waiting in a later collective.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if h.distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            # Record the stopping step; the loop exits at the top of the
            # next iteration (after a final validation).
            stop_after_step = step

    log(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB")
    log('ema: applying EMA weights')
    # Swap the EMA weights into the model, cast back to each param's dtype.
    current_state = base_model.state_dict()
    avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()}
    base_model.load_state_dict(avg_state, strict=True)
    return base_model, compiled_model


def train_and_eval(h, device):
    """Full per-rank pipeline: seed RNGs, train, serialize, then score the
    quantized round-tripped model (plus optional sliding-window/TTT evals)."""
    random.seed(h.seed)
    np.random.seed(h.seed)
    torch.manual_seed(h.seed)
    torch.cuda.manual_seed_all(h.seed)
    val_data = ValidationData(h, device)
    log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}")
    log(f"val_tokens: {val_data.val_tokens.numel() - 1}")
    base_model, compiled_model = train_model(h, device, val_data)
    torch._dynamo.reset()
    timed_eval('pre-quantization post-ema', eval_val, h, device, val_data, compiled_model)
    # The submission bundles this source file alongside the weights.
    serialize(h, base_model, Path(__file__).read_text(encoding='utf-8'))
    if h.distributed:
        dist.barrier()
    # Round-trip through the quantized artifact so we score exactly what
    # would be submitted.
    eval_model = deserialize(h, device)
    if h.num_loops > 0:
        eval_model.looping_active = True
    compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True)
    timed_eval('quantized', eval_val, h, device, val_data, compiled_model)
    if h.sliding_window_enabled:
        timed_eval('quantized_sliding_window', eval_val_sliding, h, device, val_data, eval_model)
    if h.ttt_enabled and h.sliding_window_enabled:
        # TTT updates the weights in place, so deserialize a fresh copy
        # (and free the previous models) before the TTT eval.
        del eval_model, compiled_model
        torch._dynamo.reset()
        torch.cuda.empty_cache()
        ttt_model = deserialize(h, device)
        if h.num_loops > 0:
            ttt_model.looping_active = True
        timed_eval('quantized_ttt', eval_val_ttt, h, device, val_data, ttt_model)
        del ttt_model


def main():
    """Per-rank entry point: validate the environment, set up CUDA and
    (optionally) the NCCL process group, configure backends, log run info,
    then hand off to train_and_eval()."""
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA is required')
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8")
    device = torch.device('cuda', local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend='nccl', device_id=device)
        dist.barrier()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
    # Restrict scaled-dot-product attention to the flash kernel only.
    from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp,
                                     enable_math_sdp, enable_mem_efficient_sdp)
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)
    torch._dynamo.config.optimize_ddp = False
    h = Hyperparameters()
    set_logging_hparams(h)
    if h.is_main_process:
        os.makedirs('logs', exist_ok=True)
    log('=' * 100, console=False)
    log('Hyperparameters:')
    # vars(type(h)) lists the class-level attributes, i.e. the declared
    # hyperparameter defaults (underscore names excluded).
    for k, v in sorted(vars(type(h)).items()):
        if not k.startswith('_'):
            log(f"  {k}: {v}")
    log('=' * 100, console=False)
    log(f"Running Python {sys.version}", console=False)
    log(f"Running PyTorch {torch.__version__}", console=False)
    log(subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                       text=True, check=False).stdout,
console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == '__main__': + main() diff --git a/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_log_baseline.log b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_log_baseline.log new file mode 100644 index 0000000000..e58caa6721 --- /dev/null +++ b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_log_baseline.log @@ -0,0 +1,64 @@ +logs/gauntlet_baseline_012113.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:5 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms +step:1/2000 train_loss:6.9357 train_time:471ms step_avg:471.35ms +step:2/2000 train_loss:16.7414 train_time:977ms step_avg:488.32ms +step:3/2000 train_loss:8.7524 train_time:1494ms step_avg:497.98ms +step:4/2000 train_loss:6.5885 train_time:2021ms step_avg:505.21ms +step:5/2000 train_loss:6.6522 train_time:2572ms step_avg:514.48ms +step:6/2000 train_loss:6.6999 
train_time:3075ms step_avg:512.47ms +step:7/2000 train_loss:6.3335 train_time:3603ms step_avg:514.65ms +step:8/2000 train_loss:6.1811 train_time:4162ms step_avg:520.30ms +step:9/2000 train_loss:6.0837 train_time:4645ms step_avg:516.09ms +step:10/2000 train_loss:5.9956 train_time:5172ms step_avg:517.15ms +step:200/2000 train_loss:2.7943 train_time:105274ms step_avg:526.37ms +step:400/2000 train_loss:2.3993 train_time:210577ms step_avg:526.44ms +step:500/2000 val_loss:2.5003 val_bpb:1.4808 train_time:263172ms step_avg:526.34ms +step:600/2000 train_loss:2.4985 train_time:316018ms step_avg:526.70ms +step:800/2000 train_loss:2.3520 train_time:421852ms step_avg:527.31ms +step:1000/2000 train_loss:2.4717 train_time:527119ms step_avg:527.12ms +step:1000/2000 val_loss:2.3214 val_bpb:1.3748 train_time:527120ms step_avg:527.12ms +step:1200/2000 train_loss:2.2456 train_time:632752ms step_avg:527.29ms +step:1400/2000 train_loss:2.1673 train_time:739270ms step_avg:528.05ms +step:1500/2000 val_loss:2.2396 val_bpb:1.3264 train_time:791889ms step_avg:527.93ms +step:1600/2000 train_loss:2.1944 train_time:845076ms step_avg:528.17ms +step:1800/2000 train_loss:2.1528 train_time:955409ms step_avg:530.78ms +step:2000/2000 train_loss:2.2258 train_time:1066704ms step_avg:533.35ms +step:2000/2000 val_loss:2.1889 val_bpb:1.2964 train_time:1066705ms step_avg:533.35ms +peak memory allocated: 10238 MiB reserved: 10730 MiB +Serialized model: 67224983 bytes +Code size: 47686 bytes +Total submission size: 67272669 bytes +Serialized model int8+zlib: 14961007 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 15008693 bytes +final_int8_zlib_roundtrip val_loss:2.1912 val_bpb:1.2977 eval_time:11181ms +final_int8_zlib_roundtrip_exact val_loss:2.19117971 val_bpb:1.29774005 diff --git a/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_log_seed1337.log 
b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_log_seed1337.log new file mode 100644 index 0000000000..b965e907ad --- /dev/null +++ b/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_log_seed1337.log @@ -0,0 +1,176 @@ +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /workspace/parameter-golf/data + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 8 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 2000 + ln_scale: True + local_rank: 0 + logfile: logs/gauntlet_incr_QK55_4loop_015227.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 0.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 3 + parallel_residual_start: 5 + qk_gain_init: 5.5 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: gauntlet_incr_QK55_4loop_015227 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_train_*.bin + 
train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 24576 + ttt_enabled: False + ttt_epochs: 4 + ttt_lr: 0.005 + ttt_momentum: 0.9 + ttt_selective: True + val_batch_tokens: 524288 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin + val_loss_every: 500 + vocab_size: 1024 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 1 +train_shards: 5 +val_tokens: 62021632 +model_params: 32276568 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup: enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4, 5, 3] decoder:[4, 5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/2000 val_loss: 6.9384 val_bpb: 4.1093 +1/2000 train_loss: 6.9345 train_time: 0.0m tok/s: 1200803 +2/2000 train_loss: 12.1271 train_time: 0.0m tok/s: 1189039 +3/2000 train_loss: 9.9963 train_time: 0.0m tok/s: 1178249 +4/2000 train_loss: 7.8728 train_time: 0.0m tok/s: 1174070 +5/2000 train_loss: 6.5889 train_time: 0.1m tok/s: 1171437 +500/2000 train_loss: 2.3018 train_time: 5.7m tok/s: 1146850 +500/2000 val_loss: 2.3087 val_bpb: 1.3673 +layer_loop: enabled step:700 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4, 5, 3] decoder:[4, 5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/2000 train_loss: 2.1383 train_time: 14.1m tok/s: 927822 +1000/2000 val_loss: 2.1413 val_bpb: 1.2682 +1500/2000 train_loss: 2.0394 train_time: 24.3m tok/s: 808192 +1500/2000 val_loss: 2.0580 val_bpb: 1.2188 +2000/2000 train_loss: 1.9188 train_time: 34.5m tok/s: 758935 +2000/2000 val_loss: 1.9713 val_bpb: 1.1675 +peak memory allocated: 39808 MiB reserved: 39916 MiB +ema: applying EMA weights +pre-quantization post-ema val_loss:1.97817906 val_bpb:1.17158907 eval_time:33143ms +Serialized model: 128099193 bytes +Code size: 65269 bytes +GPTQ: 
collecting Hessians from calibration data... +GPTQ: collected 67 Hessians in 17.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +[rank0]: Traceback (most recent call last): +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1385, in +[rank0]: main() +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1379, in main +[rank0]: train_and_eval(h, device) +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 1319, in train_and_eval +[rank0]: serialize(h, base_model, Path(__file__).read_text(encoding='utf-8')) +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 899, in serialize +[rank0]: quant_blob = _compress(quant_raw, h.compressor) +[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py", line 864, in _compress +[rank0]: import brotli +[rank0]: ModuleNotFoundError: No module named 'brotli' +[rank0]:[W423 03:24:44.766043241 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. 
For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) +E0423 03:24:46.722000 8738 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 0 (pid: 8807) of binary: /usr/local/bin/python +Traceback (most recent call last): + File "/usr/local/bin/torchrun", line 7, in + sys.exit(main()) + ^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 936, in main + run(args) + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 927, in run + elastic_launch( + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 156, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 293, in launch_agent + raise ChildFailedError( +torch.distributed.elastic.multiprocessing.errors.ChildFailedError: +============================================================ +/workspace/parameter-golf/records/track_10min_16mb/2026-04-23_QK55_4Loop_EarlyParResid_SelectiveTTT/train_gpt.py FAILED +------------------------------------------------------------ +Failures: + +------------------------------------------------------------ +Root Cause (first observed failure): +[0]: + time : 2026-04-23_03:24:46 + host : abfc9671a5d1 + rank : 0 (local_rank: 0) + exitcode : 1 (pid: 8807) + error_file: + traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html +============================================================ diff --git a/runpod_setup.sh b/runpod_setup.sh new file mode 100755 index 0000000000..1b096a0c3e --- /dev/null +++ b/runpod_setup.sh @@ -0,0 +1,51 @@ 
#!/usr/bin/env bash
# ============================================================
# RunPod setup for OpenAI Parameter Golf
# Run this once after SSHing into your RunPod pod.
# Overridable via env: VARIANT (tokenizer variant, default sp1024)
#                      TRAIN_SHARDS (shard count, default 10)
# ============================================================
set -e

REPO_DIR="/workspace/parameter-golf"
VARIANT="${VARIANT:-sp1024}"
TRAIN_SHARDS="${TRAIN_SHARDS:-10}"

# Clone only if the repo is not already present.
echo "==> Cloning repo..."
if [ ! -d "$REPO_DIR" ]; then
  cd /workspace
  git clone https://github.com/openai/parameter-golf.git
fi
cd "$REPO_DIR"

echo "==> Pulling latest..."
git pull

echo "==> Installing Python deps..."
pip install -r requirements.txt -q

echo "==> Downloading FineWeb dataset (variant=$VARIANT, shards=$TRAIN_SHARDS)..."
python3 data/cached_challenge_fineweb.py --variant "$VARIANT" --train-shards "$TRAIN_SHARDS"

# Print copy-pasteable launch commands; nothing below mutates the pod.
echo ""
echo "==> Setup complete. To launch training:"
echo ""
echo "  # 1xH100 baseline:"
echo "  RUN_ID=baseline_sp1024 \\"
echo "  DATA_PATH=./data/datasets/fineweb10B_sp1024/ \\"
echo "  TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \\"
echo "  VOCAB_SIZE=1024 \\"
echo "  torchrun --standalone --nproc_per_node=1 train_gpt.py"
echo ""
echo "  # 8xH100 (full leaderboard run):"
echo "  RUN_ID=sota_8xh100 \\"
echo "  DATA_PATH=./data/datasets/fineweb10B_sp1024/ \\"
echo "  TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \\"
echo "  VOCAB_SIZE=1024 \\"
echo "  torchrun --standalone --nproc_per_node=8 train_gpt.py"
echo ""
echo "  # Run the current SOTA config (best model: 1.0810 bpb):"
echo "  cd records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT"
echo "  RUN_ID=sota_repro \\"
echo "  DATA_PATH=../../../data/datasets/fineweb10B_sp8192/ \\"
echo "  TOKENIZER_PATH=../../../data/tokenizers/fineweb_8192_spm.model \\"
echo "  VOCAB_SIZE=8192 \\"
echo "  torchrun --standalone --nproc_per_node=8 train_gpt.py"
diff --git a/sync_to_runpod.sh b/sync_to_runpod.sh
new file mode 100755
index 0000000000..9d9e2b4507
--- /dev/null
+++ 
b/sync_to_runpod.sh @@ -0,0 +1,51 @@
#!/usr/bin/env bash
# ============================================================
# Sync local changes to a RunPod pod via rsync over SSH.
# Usage: RUNPOD_HOST=user@host RUNPOD_PORT=port ./sync_to_runpod.sh [--watch]
#
# Set RUNPOD_HOST (and RUNPOD_PORT if not 22) in your shell or .env
# before running.
# Example:
#   export RUNPOD_HOST="root@213.34.12.XX"
#   export RUNPOD_PORT="22204"
#   ./sync_to_runpod.sh --watch
# ============================================================
set -e

# RUNPOD_HOST is mandatory; the ":?" expansion aborts with a message if unset.
HOST="${RUNPOD_HOST:?Set RUNPOD_HOST=user@ip}"
PORT="${RUNPOD_PORT:-22}"
REMOTE_DIR="${REMOTE_DIR:-/workspace/parameter-golf}"
# Sync from the directory containing this script, wherever it is invoked from.
LOCAL_DIR="$(cd "$(dirname "$0")" && pwd)"

# Skip VCS/venv caches and the large dataset/tokenizer directories.
RSYNC_OPTS=(
  -avz --progress
  --exclude=".venv"
  --exclude=".git"
  --exclude="__pycache__"
  --exclude="*.pyc"
  --exclude="data/datasets"
  --exclude="data/tokenizers"
  -e "ssh -p $PORT"
)

do_sync() {
  rsync "${RSYNC_OPTS[@]}" "$LOCAL_DIR/" "$HOST:$REMOTE_DIR/"
  echo "[$(date '+%H:%M:%S')] Synced to $HOST:$REMOTE_DIR"
}

if [ "${1}" = "--watch" ]; then
  echo "Watching for changes (Ctrl-C to stop)..."
  do_sync
  # Use inotifywait if available, else poll every 5 seconds.
  if command -v inotifywait &>/dev/null; then
    while inotifywait -r -e modify,create,delete,move \
      --exclude '\.git|\.venv|__pycache__|data/datasets|data/tokenizers' \
      "$LOCAL_DIR" 2>/dev/null; do
      sleep 0.5
      do_sync
    done
  else
    while sleep 5; do do_sync; done
  fi
else
  do_sync
fi