diff --git a/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/README.md b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/README.md
new file mode 100644
index 0000000000..eb37f73624
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/README.md
@@ -0,0 +1,110 @@
+# Progressive Depth + Hedge Mixer (Depth Recurrence)
+
+**val_bpb: 1.1441** (3-seed mean, std 0.0051) | **~15.88 MB** | 8×H100 SXM
+
+## Results (8×H100 80GB SXM, PyTorch 2.5.1)
+
+| Seed | Steps | Step avg | Roundtrip bpb | Sliding bpb | **Hedge bpb** | Eval time |
+|------|-------|----------|---------------|-------------|---------------|-----------|
+| 1337 | 5,668 | 105.8ms | 1.2302 | 1.1965 | **1.1441** | 580s |
+| 42 | 5,170 | 116.1ms | 1.2298 | 1.1962 | **1.1491** | 580s |
+| 7 | 5,405 | 111.0ms | 1.2286 | 1.1952 | **1.1390** | 587s |
+| **Mean** | **5,414** | **111.0ms** | **1.2295** | **1.1960** | **1.1441 (std 0.0051)** | **~582s** |
+
+## Architecture: Depth Recurrence
+
+Instead of 9-11 unique transformer blocks, we use **3 shared blocks repeated 4 times** (12 effective layers). This trades unique parameters for effective depth, fitting more computation into the 16MB budget.
+
+```
+3 blocks × 4 repeats = 12 effective layers, 17.14M params
+```
+
+### Key components
+
+- **Cross-Repeat Skip**: Each block receives a weighted residual from its own output in the previous repeat, turning the stateless recurrence into a stateful one. Per-repeat learned scales.
+- **Loop Embedding**: Learned per-layer vector added before each block — depth-wise positional encoding for shared weights.
+- **Value Embeddings**: 2 extra embedding tables mixed into the residual stream at each effective layer with learned scales.
+- **XSA (Exclusive Self-Attention)**: On the last 4 effective layers — prevents attention collapse in deep recurrent models.
+- **LeakyReLU(0.5)²**: Better gradient flow than ReLU² for deep/recurrent models.
+
+### Model config
+
+| Parameter | Value |
+|-----------|-------|
+| Layers × Repeats | 3 × 4 (12 effective) |
+| Model dim | 832 |
+| Heads / KV heads | 8 / 4 |
+| MLP multiplier | 2× |
+| Vocab size | 1024 (SP BPE) |
+| Logit softcap | 30.0 |
+
+## Key Innovation: Progressive Depth Training
+
+Unique to shared-weight architectures — train with increasing recurrence depth over time:
+
+| Phase | Time fraction | Repeats | Step speed |
+|-------|--------------|---------|------------|
+| Phase 1 | 0–40% | 2 | ~80ms |
+| Phase 2 | 40–65% | 3 | ~90ms |
+| Phase 3 | 65–100% | 4 | ~105ms |
+
+This gives **+30% more training steps** than training at full depth the entire time (5,414 vs ~4,300 steps). Early phases are cheaper because fewer repeats mean a faster forward/backward pass. The model learns basic representations quickly at shallow depth, then refines with full recurrence.
+
+`torch._dynamo.reset()` + recompile on each phase transition (~10s × 2 = 20s total overhead).
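+
+A minimal sketch of the phase lookup, mirroring the schedule parsing in `train_gpt.py` (the `repeats_for` helper name is illustrative, not part of the script; the schedule string format is described just below):
+
+```python
+def repeats_for(frac: float, schedule: str = "0.4:2,0.65:3,1.0:4") -> int:
+    # "0.4:2" reads: while the elapsed wallclock fraction is < 0.4, run 2 repeats.
+    phases = sorted(
+        (float(f), int(r)) for f, r in (entry.split(":") for entry in schedule.split(","))
+    )
+    for phase_frac, phase_rep in phases:
+        if frac < phase_frac:
+            return phase_rep
+    return phases[-1][1]  # past the last boundary: stay at full depth
+```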
+
+Controlled by env var: `PROG_DEPTH="0.4:2,0.65:3,1.0:4"`
+
+## Eval: Hedge Mixer (5-Expert Online Ensemble)
+
+Eval-time improvement via an online mixture of 5 experts using the Hedge (multiplicative weights) algorithm:
+
+| Expert | Description |
+|--------|-------------|
+| Neural | Model's own logits (log-softmax) |
+| Unigram | Global token frequency with Laplace smoothing |
+| Bigram | Conditional P(token \| prev_token) |
+| Trigram | Hashed trigram context (65K buckets) |
+| Entropy | Model's own entropy as calibration signal |
+
+The mixer processes validation windows sequentially, updating n-gram statistics and expert weights after scoring each window. Weights start biased toward the neural expert (log_weight = 2.0) and update with learning rate η = 0.1.
+
+**Hedge provides a −0.052 bpb improvement** over sliding window eval (1.1960 → 1.1441 mean).
+
+### Timing budget
+
+| Phase | Time |
+|-------|------|
+| Training (10 min cap) | 600s |
+| Roundtrip eval | ~14s |
+| Sliding window eval | ~67s |
+| Hedge Mixer eval | ~582s |
+
+## Training details
+
+- **Optimizer**: Muon (matrix params) + Adam (scalars, embeddings)
+- **LR**: matrix 0.012, scalar 0.012, tied_embed 0.015
+- **Muon WD**: 0.04
+- **Warmdown**: 3000 steps (wallclock-proportional)
+- **SWA**: During warmdown, every 50 steps, 13-16 checkpoints averaged
+- **Grad clip**: 0.3
+- **Quantization**: int8 + zstd-22 (~15.88 MB artifact)
+
+## Evolution & Prior PRs
+
+This submission is the result of iterative development across several PRs in this repo:
+
+| PR | Date | Score | What changed |
+|----|------|-------|-------------|
+| [#148](https://github.com/openai/parameter-golf/pull/148) | Mar 20 | 1.2196 | Depth recurrence (3×4), cross-repeat skip, value embeddings, sliding window eval |
+| [#784](https://github.com/openai/parameter-golf/pull/784) | Mar 25 | 1.2065 | + XSA(4), LeakyReLU², GPTQ-lite, zstd-22 |
+| [#835](https://github.com/openai/parameter-golf/pull/835) | Mar 26 | 1.1980 | + Progressive depth training (+30% steps) |
+| [#856](https://github.com/openai/parameter-golf/pull/856) | Mar 26 | 1.1454 | + Hedge Mixer (5-expert eval-time ensemble) |
+| **This PR** | Apr 5 | **1.1441** | Clean submission with 3-seed validation |
+
+This PR supersedes the above with a clean diff and proper 3-seed statistical validation.
+
+## Lineage
+
+- Depth recurrence architecture is original to this submission line
+- XSA from PR #198 (unnir), LeakyReLU² from PR #493 (parinzee)
+- SWA and Muon WD from the modded-nanogpt community
diff --git a/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/submission.json b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/submission.json
new file mode 100644
index 0000000000..0052ae5d01
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/submission.json
@@ -0,0 +1,19 @@
+{
+  "author": "Ivan Verbovoy",
+  "github_id": "iverbovoy",
+  "name": "Progressive Depth + Hedge Mixer (3x4 depth recurrence, 5-expert online ensemble)",
+  "blurb": "3 shared blocks x 4 repeats (12 effective layers) with progressive depth scheduling (2→3→4 repeats), XSA, LeakyReLU², Cross-Repeat Skip, SWA, int8+zstd22. Eval: 5-expert Hedge Mixer (neural + unigram + bigram + trigram + entropy) with online multiplicative weight updates.
Mean over 3 seeds.", + "date": "2026-04-05T00:00:00Z", + "val_loss": 1.93171750, + "val_bpb": 1.14407142, + "roundtrip_val_loss": 2.07601138, + "roundtrip_val_bpb": 1.22953088, + "sliding_val_loss": 2.01933888, + "sliding_val_bpb": 1.19596573, + "seeds": [1337, 42, 7], + "mean_steps": 5414, + "wallclock_seconds": 600, + "eval_seconds": 580, + "bytes_model_int8_zstd22": 15818418, + "bytes_code": 65854 +} diff --git a/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_gpt.py b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_gpt.py new file mode 100644 index 0000000000..1738288f39 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_gpt.py @@ -0,0 +1,1498 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zstandard as zstd +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + + +class HedgeMixer: + """Online mixture of 5 experts via Hedge algorithm for eval-time improvement. + Experts: Neural, Unigram, Bigram, Trigram (hashed), Entropy.""" + def __init__(self, vocab_size: int = 1024, device: str = "cuda", eta: float = 0.1): + self.V = vocab_size + self.device = device + self.eta = eta + self.log_weights = torch.zeros(5, device=device) + self.log_weights[0] = 2.0 # bias toward neural + self.uni_counts = torch.zeros(vocab_size, device=device) + self.bi_counts = torch.zeros(vocab_size, vocab_size, device=device) + self.total_tokens = 0 + self.TRI_HASH = 65536 + self.tri_counts = torch.zeros(self.TRI_HASH, vocab_size, device=device) + self.tri_row_totals = torch.zeros(self.TRI_HASH, device=device) + + def update(self, tokens: Tensor) -> None: + t = tokens.to(self.device).long() + n = t.numel() + if n == 0: + return + self.total_tokens += n + ones = torch.ones(n, device=self.device) + self.uni_counts.scatter_add_(0, t, ones) + if n >= 2: + bi_idx = t[:-1] * self.V + t[1:] + self.bi_counts.reshape(-1).scatter_add_(0, bi_idx, torch.ones(n - 1, device=self.device)) + if n >= 3: + tri_ctx = ((t[:-2] * 36313) ^ (t[1:-1] * 27191)) % self.TRI_HASH + tri_idx = tri_ctx * self.V + t[2:] + ones_tri = torch.ones(n - 2, device=self.device) + self.tri_counts.reshape(-1).scatter_add_(0, tri_idx, ones_tri) + self.tri_row_totals.scatter_add_(0, tri_ctx, ones_tri) + + def mix_and_score(self, neural_logits: Tensor, x_batch: Tensor, y_batch: Tensor, wlens: list[int]) -> Tensor: + bsz, slen, V = neural_logits.shape + uniform_nll = math.log(self.V) + has_data = self.total_tokens > 0 + neural_lp = F.log_softmax(neural_logits, dim=-1) + neural_nll = -neural_lp.gather(2, y_batch.unsqueeze(2)).squeeze(2) + if not has_data or self.total_tokens < 10000: + return neural_nll + uni_probs = (self.uni_counts + 0.1) / (self.total_tokens + 0.1 * self.V) + uni_nll = -uni_probs.log()[y_batch] + bi_total = self.bi_counts.sum(dim=1, 
keepdim=True) + bi_probs = (self.bi_counts + 0.1) / (bi_total + 0.1 * self.V) + bi_nll = -bi_probs.log()[x_batch.reshape(-1), y_batch.reshape(-1)].reshape(bsz, slen) + if slen >= 2: + prev2 = torch.zeros_like(x_batch) + prev2[:, 1:] = x_batch[:, :-1] + ctx_hash = ((prev2 * 36313) ^ (x_batch * 27191)) % self.TRI_HASH + tri_count = self.tri_counts[ctx_hash.reshape(-1).long(), y_batch.reshape(-1).long()] + tri_total = self.tri_row_totals[ctx_hash.reshape(-1).long()].clamp(min=1) + tri_nll = -(((tri_count + 0.01) / (tri_total + 0.01 * self.V)).log()).reshape(bsz, slen) + else: + tri_nll = torch.full((bsz, slen), uniform_nll, device=self.device) + entropy_nll = -(neural_lp.exp() * neural_lp).sum(-1) + expert_nll = torch.stack([neural_nll, uni_nll, bi_nll, tri_nll, entropy_nll], dim=-1) + log_w = self.log_weights - self.log_weights.logsumexp(0) + mixed_nll = -(-expert_nll + log_w.unsqueeze(0).unsqueeze(0)).logsumexp(dim=-1) + # Update weights + wlens_t = torch.tensor(wlens, device=self.device, dtype=torch.long) + mask = torch.arange(slen, device=self.device).unsqueeze(0) < wlens_t.unsqueeze(1) + masked_nll = expert_nll * mask.unsqueeze(-1).float() + expert_mean_loss = masked_nll.sum(dim=(0, 1)) / mask.sum().clamp(min=1) + self.log_weights -= self.eta * expert_mean_loss + return mixed_nll + + +# HYPERPARAMETERS + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + # Progressive Depth: train with fewer repeats early (faster), more repeats later (deeper). + # Schedule format: "frac1:rep1,frac2:rep2,..." e.g. "0.4:2,0.65:3,1.0:4" + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") + + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Hedge Mixer (eval-time n-gram ensemble). 
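+    # Mixing rule: log p_mix = logsumexp_i(log w_i + log p_i(y)) - logsumexp_i(log w_i);
+    # after each scored window, log w_i -= HEDGE_ETA * (expert i's mean NLL on the window).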
+ use_hedge = bool(int(os.environ.get("USE_HEDGE", "1"))) + hedge_eta = float(os.environ.get("HEDGE_ETA", 0.1)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.021)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.018)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.018)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# MUON OPTIMIZER +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
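+    # Quintic Newton-Schulz: with A = X @ X.T, the update X <- a*X + (b*A + c*A @ A) @ X
+    # pushes the singular values of X toward 1, approximately orthogonalizing G.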
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    X /= X.norm() + eps
+    transposed = G.size(0) > G.size(1)
+    if transposed:
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    return X.T if transposed else X
+
+
+class Muon(torch.optim.Optimizer):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0
+
+        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            wd = group.get("weight_decay", 0.0)
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                if wd > 0:
+                    p.data.mul_(1.0 - lr * wd)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding weights can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
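+# Concretely: val_bpb = (val_loss / ln 2) * (total_tokens / total_bytes), where total_bytes
+# sums the UTF-8 bytes each target token decodes to (including re-inserted leading spaces).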
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * 
batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+
+def eval_val_sliding(
+    args: Hyperparameters,
+    base_model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+    use_hedge: bool = False,
+    hedge_eta: float = 0.1,
+) -> tuple[float, float]:
+    """Sliding window eval with batching. Windows of eval_seq_len tokens advance by eval_stride.
+    Only the last stride tokens per window are scored (the first window scores all of its tokens).
+    Optional Hedge Mixer: online n-gram ensemble over scored tokens."""
+    seq_len = args.eval_seq_len
+    stride = args.eval_stride
+    batch_seqs = args.eval_batch_seqs
+    total_tokens = val_tokens.numel() - 1
+
+    window_starts = [ws for ws in range(0, total_tokens, stride)
+                     if min(ws + seq_len, total_tokens) - ws >= 1]
+
+    # With Hedge Mixer: process ALL windows on each rank (sequential; the n-gram tables need full context).
+    # Without: distribute windows across ranks.
+    if use_hedge:
+        my_windows = window_starts
+    else:
+        total_windows = len(window_starts)
+        my_s = (total_windows * rank) // world_size
+        my_e = (total_windows * (rank + 1)) // world_size
+        my_windows = window_starts[my_s:my_e]
+
+    mixer = HedgeMixer(vocab_size=args.vocab_size, device=device, eta=hedge_eta) if use_hedge else None
+
+    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
+    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    base_model.eval()
+    with torch.inference_mode():
+        for bi in range(0, len(my_windows), batch_seqs):
+            batch_ws = my_windows[bi : bi + batch_seqs]
+            bsz = len(batch_ws)
+            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            wlens: list[int] = []
+
+            for i, ws in enumerate(batch_ws):
+                end = min(ws + seq_len, total_tokens)
+                wlen = end - ws
+                wlens.append(wlen)
+                chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device)
+                x_batch[i, :wlen] = chunk[:-1]
+                y_batch[i, :wlen] = chunk[1:]
+
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                logits = base_model.forward_logits(x_batch)
+
+            if mixer is not None:
+                nll = mixer.mix_and_score(logits.float(), x_batch, y_batch, wlens)
+            else:
+                nll = F.cross_entropy(
+                    logits.reshape(-1, logits.size(-1)).float(),
+                    y_batch.reshape(-1),
+                    reduction="none",
+                ).reshape(bsz, seq_len)
+
+            for i, ws in enumerate(batch_ws):
+                wlen = wlens[i]
+                s = 0 if ws == 0 else max(wlen - stride, 0)
+                scored_nll = nll[i, s:wlen].to(torch.float64)
+                val_loss_sum += scored_nll.sum()
+                val_token_count += float(wlen - s)
+                prev_ids = x_batch[i, s:wlen]
+                tgt_ids =
y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # Update n-gram tables with scored tokens + if mixer is not None: + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mixer.update(y_batch[i, s:wlen]) + + if not use_hedge and dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + + +# POST-TRAINING QUANTIZATION +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. 
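+# clamp(round(x / s), -QL, QL) produces 2*QL + 1 levels, i.e. about log2(2*QL + 1) bits of
+# entropy per weight before zstd: ~8.0 bits at QL=127, ~6.0 at QL=31, ~5.0 at QL=16.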
+QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS + t32 = t.float() + if t32.ndim == 2: + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
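+        # With this model config the 65,536-element threshold keeps every control tensor
+        # (loop_embed, value_scales, cross_repeat_scales, per-block scales) in float; only
+        # the embedding tables and attention/MLP matrices take the int8 path.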
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS
+        ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS
+        q, s = quantize_float_tensor(t, ql=ql)
+        meta: dict[str, object] = {}
+        if s.ndim > 0:
+            meta["scheme"] = "per_row"
+            meta["axis"] = 0
+        if ql != QUANT_LEVELS:
+            meta["ql"] = ql
+        if meta:
+            qmeta[name] = meta
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# DATA LOADING
+
+def load_data_shard(file: Path) -> Tensor:
+    # Assumes the standard fineweb shard layout: a 256-entry uint32 header whose third
+    # entry is the token count, followed by the tokens as little-endian uint16.
+    header_bytes = 256 * np.dtype("<u4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<u4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Streams tokens from the shard files in order, wrapping around at the end.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
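+    # Example: global_tokens=524288, world_size=8, grad_accum_steps=1 gives a per-rank span
+    # of 65536 + 1 tokens; rank r keeps chunk[r * 65537 : (r + 1) * 65537].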
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# TRANSFORMER MODULES + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, 
bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, 
+ xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n + effective_depth = num_layers * num_repeats + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(cur_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + if layer_idx < self.value_scales.size(0): + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1) + scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0, use_xsa=(layer_idx >= xsa_start)) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when 
tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# TRAINING + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # DISTRIBUTED + CUDA SETUP + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION METRIC SETUP + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + + base_model = 
GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." 
and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions + if max_wallclock_ms is not None and prog_phases: + if distributed: + elapsed_tensor = torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + frac = elapsed_tensor.item() / max_wallclock_ms + else: + frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + 
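+        # A single multiplier (`scale`) drives every param group's LR: 1.0 until the
+        # wallclock-proportional warmdown begins, then linear decay with remaining time.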
+ for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (only at full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. 
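+    # quantize_state_dict_int8 / dequantize_state_dict_int8 are defined earlier in this file
+    # and are not part of this hunk; a per-tensor symmetric scheme (scale = w.abs().max()/127,
+    # q = round(w/scale), dequant = q*scale) would match the ~3.7x payload ratio reported in
+    # the logs, so that is the assumed layout.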
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Hedge Mixer eval (n-gram ensemble) + if args.use_hedge: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_hm = time.perf_counter() + hm_val_loss, hm_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + use_hedge=True, hedge_eta=args.hedge_eta, + ) + torch.cuda.synchronize() + log0( + f"final_hedge_mixer val_loss:{hm_val_loss:.4f} val_bpb:{hm_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_hm):.0f}ms" + ) + log0(f"final_hedge_mixer_exact val_loss:{hm_val_loss:.8f} val_bpb:{hm_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() 
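The Hedge Mixer itself lives inside `eval_val_sliding` and is outside this hunk; as a reader's aid, here is a minimal sketch of the multiplicative-weights (Hedge) update that the `hedge_eta` argument above plugs into. Function names and the example numbers are illustrative only, not the repo's API.

```python
import math

def hedge_update(log_weights: list[float], expert_logprobs: list[float], eta: float) -> list[float]:
    # Multiplicative weights: w_i <- w_i * exp(eta * log p_i), i.e. each expert's
    # log-weight grows by eta times the log-probability it assigned to the observed token.
    return [lw + eta * lp for lw, lp in zip(log_weights, expert_logprobs)]

def hedge_mix(log_weights: list[float], expert_probs: list[float]) -> float:
    # Mixture probability of a token: softmax(log_weights) dotted with the expert probs.
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    z = sum(w)
    return sum((wi / z) * p for wi, p in zip(w, expert_probs))

# One scoring step: mix first, then let each expert's weight absorb its own performance.
log_w = [2.0, 0.0, 0.0, 0.0, 0.0]        # one expert starts favored
p_next = [0.31, 0.05, 0.08, 0.02, 0.10]  # each expert's probability of the observed token
p_mix = hedge_mix(log_w, p_next)         # probability the ensemble is scored on
log_w = hedge_update(log_w, [math.log(p) for p in p_next], eta=0.1)
```

The max-subtraction keeps the softmax over log-weights numerically stable no matter how far the weights drift over a long eval.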
diff --git a/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed1337.log b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed1337.log new file mode 100644 index 0000000000..05841c2ea2 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed1337.log @@ -0,0 +1,114 @@ +W0405 12:59:14.349000 2645 torch/distributed/run.py:793] +W0405 12:59:14.349000 2645 torch/distributed/run.py:793] ***************************************** +W0405 12:59:14.349000 2645 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0405 12:59:14.349000 2645 torch/distributed/run.py:793] ***************************************** +logs/7257b942-e1af-4f1b-aaf7-dfd9eec348e9.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.021 head_lr:0.0 matrix_lr:0.018 scalar_lr:0.018 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.4, 2), (0.65, 3), (1.0, 4)] starting_repeats=2 +step:0/20000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9769 train_time:20842ms step_avg:20841.80ms +step:2/20000 train_loss:9.6250 train_time:20866ms step_avg:10432.84ms +step:3/20000 train_loss:9.4926 train_time:20931ms step_avg:6976.94ms +step:4/20000 train_loss:9.1982 train_time:21001ms step_avg:5250.24ms +step:5/20000 train_loss:8.6464 train_time:21072ms step_avg:4214.35ms +step:6/20000 train_loss:8.1760 train_time:21144ms step_avg:3524.06ms +step:7/20000 train_loss:7.3005 train_time:21216ms step_avg:3030.88ms +step:8/20000 train_loss:6.6961 train_time:21289ms step_avg:2661.12ms +step:9/20000 train_loss:6.1797 train_time:21363ms step_avg:2373.62ms +step:10/20000 train_loss:5.8335 train_time:21435ms step_avg:2143.47ms +step:200/20000 train_loss:2.7565 train_time:35072ms step_avg:175.36ms +step:400/20000 train_loss:2.3068 train_time:49433ms step_avg:123.58ms +step:600/20000 train_loss:2.5314 train_time:63837ms step_avg:106.39ms +step:800/20000 train_loss:2.2938 train_time:78283ms step_avg:97.85ms +step:1000/20000 train_loss:2.3860 train_time:92761ms step_avg:92.76ms +step:1000/20000 val_loss:2.3453 val_bpb:1.3890 train_time:92802ms step_avg:92.80ms +step:1200/20000 train_loss:2.4039 train_time:107258ms step_avg:89.38ms +step:1400/20000 train_loss:2.4559 train_time:121732ms step_avg:86.95ms +step:1600/20000 train_loss:2.1222 train_time:136212ms step_avg:85.13ms +step:1800/20000 train_loss:2.2289 train_time:150681ms step_avg:83.71ms +step:2000/20000 train_loss:2.2846 train_time:165145ms 
step_avg:82.57ms
+step:2000/20000 val_loss:2.2662 val_bpb:1.3422 train_time:165187ms step_avg:82.59ms
+step:2200/20000 train_loss:2.1079 train_time:179607ms step_avg:81.64ms
+step:2400/20000 train_loss:2.2331 train_time:194053ms step_avg:80.86ms
+step:2600/20000 train_loss:2.4447 train_time:208500ms step_avg:80.19ms
+step:2800/20000 train_loss:2.2801 train_time:222954ms step_avg:79.63ms
+step:3000/20000 train_loss:2.2653 train_time:237405ms step_avg:79.13ms
+step:3000/20000 val_loss:2.2366 val_bpb:1.3246 train_time:237445ms step_avg:79.15ms
+prog_depth: switched to 3 repeats at step:3036 frac:0.40
+step:3200/20000 train_loss:2.2264 train_time:278845ms step_avg:87.14ms
+step:3400/20000 train_loss:2.1915 train_time:300029ms step_avg:88.24ms
+step:3600/20000 train_loss:2.1505 train_time:321247ms step_avg:89.24ms
+step:3800/20000 train_loss:2.2496 train_time:342483ms step_avg:90.13ms
+step:4000/20000 train_loss:2.1943 train_time:363733ms step_avg:90.93ms
+step:4000/20000 val_loss:2.1992 val_bpb:1.3025 train_time:363798ms step_avg:90.95ms
+step:4200/20000 train_loss:2.2021 train_time:384941ms step_avg:91.65ms
+prog_depth: switched to 4 repeats at step:4248 frac:0.65
+step:4400/20000 train_loss:2.1387 train_time:422163ms step_avg:95.95ms
+step:4600/20000 train_loss:1.9754 train_time:450165ms step_avg:97.86ms
+step:4800/20000 train_loss:2.2500 train_time:478305ms step_avg:99.65ms
+step:5000/20000 train_loss:2.0019 train_time:506362ms step_avg:101.27ms
+step:5000/20000 val_loss:2.1254 val_bpb:1.2588 train_time:506445ms step_avg:101.29ms
+swa:start step:5100
+step:5200/20000 train_loss:2.1291 train_time:534302ms step_avg:102.75ms
+step:5400/20000 train_loss:2.1268 train_time:562387ms step_avg:104.15ms
+step:5600/20000 train_loss:2.1052 train_time:590526ms step_avg:105.45ms
+step:5668/20000 val_loss:2.0735 val_bpb:1.2280 train_time:600174ms step_avg:105.89ms
+stopping_early: wallclock_cap train_time:600174ms step:5668/20000
+peak memory allocated: 25696 MiB reserved: 27322 MiB
+swa: averaging 13 checkpoints
+Serialized model: 63386762 bytes
+Code size: 65854 bytes
+Total submission size: 63452616 bytes
+Serialized model int8+zstd22: 15816645 bytes (payload:17243616 raw_torch:17260843 payload_ratio:3.68x)
+Total submission size int8+zstd22: 15882499 bytes
+/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+ quant_blob_disk = f.read() +final_roundtrip val_loss:2.0772 val_bpb:1.2302 eval_time:14267ms +final_roundtrip_exact val_loss:2.07715950 val_bpb:1.23021086 +final_sliding_window val_loss:2.0203 val_bpb:1.1965 window:1024 stride:256 eval_time:67006ms +final_sliding_window_exact val_loss:2.02030248 val_bpb:1.19653642 +final_hedge_mixer val_loss:1.9318 val_bpb:1.1441 eval_time:579784ms +final_hedge_mixer_exact val_loss:1.93178604 val_bpb:1.14411202 diff --git a/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed42.log b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed42.log new file mode 100644 index 0000000000..aebba2b33b --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed42.log @@ -0,0 +1,111 @@ +W0405 13:26:41.139000 24833 torch/distributed/run.py:793] +W0405 13:26:41.139000 24833 torch/distributed/run.py:793] ***************************************** +W0405 13:26:41.139000 24833 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0405 13:26:41.139000 24833 torch/distributed/run.py:793] ***************************************** +logs/e58c05d6-b408-4c7d-a63d-0d3ab73372ea.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17160024 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.021 head_lr:0.0 matrix_lr:0.018 scalar_lr:0.018 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.3, 2), (0.5, 3), (1.0, 4)] starting_repeats=2 +step:0/20000 val_loss:6.9314 val_bpb:4.1051 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9785 train_time:18540ms step_avg:18540.29ms +step:2/20000 train_loss:9.6574 train_time:18561ms step_avg:9280.57ms +step:3/20000 train_loss:9.5568 train_time:18630ms step_avg:6210.11ms +step:4/20000 train_loss:9.2758 train_time:18702ms step_avg:4675.48ms +step:5/20000 train_loss:8.7871 train_time:18774ms step_avg:3754.88ms +step:6/20000 train_loss:8.3631 train_time:18848ms step_avg:3141.25ms +step:7/20000 train_loss:7.5153 train_time:18923ms step_avg:2703.26ms +step:8/20000 train_loss:6.8908 train_time:18995ms step_avg:2374.39ms +step:9/20000 train_loss:6.3561 train_time:19069ms step_avg:2118.76ms +step:10/20000 train_loss:5.9503 train_time:19143ms step_avg:1914.35ms +step:200/20000 train_loss:2.7800 train_time:32995ms step_avg:164.97ms +step:400/20000 train_loss:2.3140 train_time:47604ms step_avg:119.01ms +step:600/20000 train_loss:2.5364 train_time:62259ms step_avg:103.77ms +step:800/20000 train_loss:2.2982 train_time:76954ms step_avg:96.19ms +step:1000/20000 train_loss:2.3847 train_time:91682ms 
step_avg:91.68ms +step:1000/20000 val_loss:2.3444 val_bpb:1.3885 train_time:91723ms step_avg:91.72ms +step:1200/20000 train_loss:2.4045 train_time:106493ms step_avg:88.74ms +step:1400/20000 train_loss:2.4545 train_time:121226ms step_avg:86.59ms +step:1600/20000 train_loss:2.1201 train_time:135954ms step_avg:84.97ms +step:1800/20000 train_loss:2.2256 train_time:150682ms step_avg:83.71ms +step:2000/20000 train_loss:2.2853 train_time:165393ms step_avg:82.70ms +step:2000/20000 val_loss:2.2653 val_bpb:1.3416 train_time:165435ms step_avg:82.72ms +prog_depth: switched to 3 repeats at step:2199 frac:0.30 +step:2200/20000 train_loss:4.4422 train_time:198710ms step_avg:90.32ms +step:2400/20000 train_loss:2.2240 train_time:220156ms step_avg:91.73ms +step:2600/20000 train_loss:2.4322 train_time:241730ms step_avg:92.97ms +step:2800/20000 train_loss:2.2620 train_time:263333ms step_avg:94.05ms +step:3000/20000 train_loss:2.2471 train_time:284957ms step_avg:94.99ms +step:3000/20000 val_loss:2.2151 val_bpb:1.3119 train_time:285026ms step_avg:95.01ms +prog_depth: switched to 4 repeats at step:3140 frac:0.50 +step:3200/20000 train_loss:2.2169 train_time:319527ms step_avg:99.85ms +step:3400/20000 train_loss:2.1751 train_time:347907ms step_avg:102.33ms +step:3600/20000 train_loss:2.1254 train_time:376350ms step_avg:104.54ms +step:3800/20000 train_loss:2.2115 train_time:404821ms step_avg:106.53ms +step:4000/20000 train_loss:2.1412 train_time:433293ms step_avg:108.32ms +step:4000/20000 val_loss:2.1510 val_bpb:1.2740 train_time:433371ms step_avg:108.34ms +step:4200/20000 train_loss:2.1448 train_time:461699ms step_avg:109.93ms +swa:start step:4400 +step:4400/20000 train_loss:2.0696 train_time:490129ms step_avg:111.39ms +step:4600/20000 train_loss:1.9144 train_time:518692ms step_avg:112.76ms +step:4800/20000 train_loss:2.2006 train_time:547234ms step_avg:114.01ms +step:5000/20000 train_loss:1.9558 train_time:575768ms step_avg:115.15ms +step:5000/20000 val_loss:2.0792 val_bpb:1.2314 train_time:575877ms step_avg:115.18ms +step:5170/20000 val_loss:2.0720 val_bpb:1.2271 train_time:600152ms step_avg:116.08ms +stopping_early: wallclock_cap train_time:600152ms step:5170/20000 +peak memory allocated: 25539 MiB reserved: 26360 MiB +swa: averaging 17 checkpoints +Serialized model: 63427019 bytes +Code size: 65888 bytes +Total submission size: 63492907 bytes +Serialized model int8+zstd22: 15837812 bytes (payload:17283552 raw_torch:17301160 payload_ratio:3.67x) +Total submission size int8+zstd22: 15903700 bytes +/workspace/parameter-golf/train_gpt.py:1431: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu")
+final_roundtrip val_loss:2.0765 val_bpb:1.2298 eval_time:14472ms
+final_roundtrip_exact val_loss:2.07646450 val_bpb:1.22979924
+final_sliding_window val_loss:2.0197 val_bpb:1.1962 window:1024 stride:256 eval_time:68373ms
+final_sliding_window_exact val_loss:2.01969483 val_bpb:1.19617654
+final_hedge_mixer val_loss:1.9403 val_bpb:1.1491 eval_time:587385ms
+final_hedge_mixer_exact val_loss:1.94028891 val_bpb:1.14914789 diff --git a/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed7.log b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed7.log new file mode 100644 index 0000000000..46476bf1e0 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_ProgressiveDepth_HedgeMixer/train_seed7.log @@ -0,0 +1,113 @@
+W0405 13:52:01.099000 27624 torch/distributed/run.py:793]
+W0405 13:52:01.099000 27624 torch/distributed/run.py:793] *****************************************
+W0405 13:52:01.099000 27624 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0405 13:52:01.099000 27624 torch/distributed/run.py:793] *****************************************
+logs/67f1f3ea-6f94-4768-9dde-1184d41cb7ad.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:17160024
+world_size:8 grad_accum_steps:1
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.021 head_lr:0.0 matrix_lr:0.018 scalar_lr:0.018
+train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:7
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+prog_depth: schedule=[(0.3, 2), (0.5, 3), (1.0, 4)] starting_repeats=2
+step:0/20000 val_loss:6.9329 val_bpb:4.1060 train_time:0ms step_avg:0.01ms
+step:1/20000 train_loss:6.9803 train_time:6821ms step_avg:6821.17ms
+step:2/20000 train_loss:9.8042 train_time:6839ms step_avg:3419.68ms
+step:3/20000 train_loss:9.5164 train_time:6911ms step_avg:2303.74ms
+step:4/20000 train_loss:8.7997 train_time:6984ms step_avg:1745.94ms
+step:5/20000 train_loss:7.6852 train_time:7058ms step_avg:1411.58ms
+step:6/20000 train_loss:6.8925 train_time:7129ms step_avg:1188.15ms
+step:7/20000 train_loss:5.9917 train_time:7203ms step_avg:1028.98ms
+step:8/20000 train_loss:5.7313 train_time:7279ms step_avg:909.87ms
+step:9/20000 train_loss:5.6082 train_time:7353ms step_avg:817.00ms
+step:10/20000 train_loss:5.4280 train_time:7427ms step_avg:742.74ms
+step:200/20000 train_loss:2.7335 train_time:21311ms step_avg:106.56ms
+step:400/20000 train_loss:2.3165 train_time:35956ms step_avg:89.89ms
+step:600/20000 train_loss:2.5395 train_time:50658ms step_avg:84.43ms
+step:800/20000 train_loss:2.3015 train_time:65400ms step_avg:81.75ms
+step:1000/20000 train_loss:2.3923 train_time:80157ms step_avg:80.16ms +step:1000/20000 val_loss:2.3513 val_bpb:1.3926 train_time:80199ms step_avg:80.20ms +step:1200/20000 train_loss:2.4034 train_time:94922ms step_avg:79.10ms +step:1400/20000 train_loss:2.4556 train_time:109688ms step_avg:78.35ms +step:1600/20000 train_loss:2.1259 train_time:124454ms step_avg:77.78ms +step:1800/20000 train_loss:2.2303 train_time:139206ms step_avg:77.34ms +step:2000/20000 train_loss:2.2890 train_time:153957ms step_avg:76.98ms +step:2000/20000 val_loss:2.2703 val_bpb:1.3446 train_time:153998ms step_avg:77.00ms +step:2200/20000 train_loss:2.1112 train_time:168695ms step_avg:76.68ms +prog_depth: switched to 3 repeats at step:2353 frac:0.30 +step:2400/20000 train_loss:2.2695 train_time:194610ms step_avg:81.09ms +step:2600/20000 train_loss:2.4463 train_time:216165ms step_avg:83.14ms +step:2800/20000 train_loss:2.2693 train_time:237754ms step_avg:84.91ms +step:3000/20000 train_loss:2.2529 train_time:259346ms step_avg:86.45ms +step:3000/20000 val_loss:2.2211 val_bpb:1.3154 train_time:259415ms step_avg:86.47ms +step:3200/20000 train_loss:2.2133 train_time:280908ms step_avg:87.78ms +prog_depth: switched to 4 repeats at step:3377 frac:0.50 +step:3400/20000 train_loss:2.2141 train_time:314033ms step_avg:92.36ms +step:3600/20000 train_loss:2.1395 train_time:342468ms step_avg:95.13ms +step:3800/20000 train_loss:2.2343 train_time:370944ms step_avg:97.62ms +step:4000/20000 train_loss:2.1677 train_time:399441ms step_avg:99.86ms +step:4000/20000 val_loss:2.1719 val_bpb:1.2863 train_time:399520ms step_avg:99.88ms +step:4200/20000 train_loss:2.1676 train_time:427927ms step_avg:101.89ms +step:4400/20000 train_loss:2.0909 train_time:456410ms step_avg:103.73ms +step:4600/20000 train_loss:1.9321 train_time:484923ms step_avg:105.42ms +swa:start step:4700 +step:4800/20000 train_loss:2.2172 train_time:513500ms step_avg:106.98ms +step:5000/20000 train_loss:1.9730 train_time:542084ms step_avg:108.42ms +step:5000/20000 val_loss:2.0940 val_bpb:1.2402 train_time:542194ms step_avg:108.44ms +step:5200/20000 train_loss:2.0999 train_time:570672ms step_avg:109.74ms +step:5400/20000 train_loss:2.1064 train_time:599264ms step_avg:110.97ms +step:5405/20000 val_loss:2.0706 val_bpb:1.2263 train_time:600082ms step_avg:111.02ms +stopping_early: wallclock_cap train_time:600082ms step:5405/20000 +peak memory allocated: 25540 MiB reserved: 26120 MiB +swa: averaging 16 checkpoints +Serialized model: 63427019 bytes +Code size: 65888 bytes +Total submission size: 63492907 bytes +Serialized model int8+zstd22: 15835153 bytes (payload:17283552 raw_torch:17301160 payload_ratio:3.67x) +Total submission size int8+zstd22: 15901041 bytes +/workspace/parameter-golf/train_gpt.py:1431: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.
+ quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu")
+final_roundtrip val_loss:2.0744 val_bpb:1.2286 eval_time:13681ms
+final_roundtrip_exact val_loss:2.07441014 val_bpb:1.22858253
+final_sliding_window val_loss:2.0180 val_bpb:1.1952 window:1024 stride:256 eval_time:68597ms
+final_sliding_window_exact val_loss:2.01801934 val_bpb:1.19518422
+final_hedge_mixer val_loss:1.9231 val_bpb:1.1390 eval_time:590269ms
+final_hedge_mixer_exact val_loss:1.92307754 val_bpb:1.13895436