diff --git a/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/README.md b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/README.md new file mode 100644 index 0000000000..320438061b --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/README.md @@ -0,0 +1,99 @@ +# WaterLOO: Full-Rescore N-gram Cache with Self-Exclusion + +**val_bpb: 0.0990 (3-seed mean, std 0.00002) | ~15.87 MB | 8xH100 SXM** + +## Results + +| Seed | Steps | Pre-Quant BPB | Sliding BPB | N-gram BPB | Artifact | +|------|-------|---------------|-------------|------------|----------| +| 1337 | 6933 | 1.1395 | 1.1253 | **0.09897** | 15.89 MB | +| 42 | 6930 | 1.1409 | 1.1268 | **0.09897** | 15.86 MB | +| 2025 | 6930 | 1.1410 | 1.1271 | **0.09902** | 15.87 MB | +| **Mean** | 6931 | **1.1405** | **1.1264** | **0.09899** | **15.87 MB** | +| **Std** | | | | **0.00002** | | + +## The Idea + +BROADSIDE showed that once you decouple the neural forward pass from the n-gram scoring, the usual two-pass bottleneck mostly disappears. You can store per-token neural probabilities in Pass 1, build a complete cache in one fast vectorized shot, and then rescore the validation stream against that complete cache while there is still plenty of eval clock left. + +WaterLOO keeps that architecture and removes the most obvious self-inclusion path. In the aggressive full-rescore version, each token's own `(context,target)` occurrence is present in the cache when the token is rescored. Here, Pass 2 performs **leave-one-out scoring**: + +- subtract `1` from the token's context count +- subtract `1` from the token's `(context,target)` count +- then apply the same backoff, `min_count`, entropy-adaptive alpha, and order multipliers as before + +So every token still benefits from a globally warm cache, but it no longer gets to vote for itself. That is a stricter and more conservative use of the same full-rescore machinery. 
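The leave-one-out steps listed above can be sketched in a few lines of NumPy. This is a minimal illustration and not the code in `train_gpt.py`: the function name `loo_ngram_prob`, the toy order-2 stream, and the collision-free bucket indexing are all assumptions made for the example, and backoff, entropy-adaptive alpha, and order multipliers are left out.

```python
import numpy as np

def loo_ngram_prob(ctx_counts, pair_counts, ctx_idx, pair_idx, min_count=2):
    """Leave-one-out n-gram probability for every scored token (hypothetical helper).

    ctx_counts / pair_counts hold full-stream bucket counts, so each token's own
    occurrence is already included. ctx_idx / pair_idx give the bucket of each
    token's context and of its (context, target) pair.
    """
    # Subtract the token's own occurrence from both counts.
    ctx = ctx_counts[ctx_idx].astype(np.int64) - 1
    pair = pair_counts[pair_idx].astype(np.int64) - 1
    # Eligibility is decided on the reduced context count.
    eligible = ctx >= min_count
    prob = np.zeros(ctx.shape, dtype=np.float64)
    # Clamp so hash collisions can never push the ratio outside [0, 1].
    num = np.clip(np.minimum(pair, ctx), 0, None)
    prob[eligible] = num[eligible] / ctx[eligible]
    return prob, eligible

# Toy order-2 example: one token of context, exact (unhashed) buckets.
tokens = np.array([1, 2, 1, 2, 1, 3, 1, 2])
vocab = 4
ctx_idx = tokens[:-1]
pair_idx = tokens[:-1] * vocab + tokens[1:]
ctx_counts = np.bincount(ctx_idx, minlength=vocab)
pair_counts = np.bincount(pair_idx, minlength=vocab * vocab)

prob, eligible = loo_ngram_prob(ctx_counts, pair_counts, ctx_idx, pair_idx)
```

In this toy stream the pair `(1, 3)` occurs exactly once, so after subtraction its count is zero and the token gets no n-gram support from its own occurrence, whereas the repeated pair `(1, 2)` still scores 2/3 from the other two occurrences.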
+ +## Architecture + +1. **Pass 1** (~89s): standard sliding-window neural eval, storing per-token `model_p` and entropy in numpy arrays +2. **Cache build** (~32-34s): build the complete order `2-12` hashed n-gram cache from the validation stream via `np.bincount` +3. **Pass 2** (~22s): rescore all tokens against the full cache with leave-one-out count subtraction + +The important result is that this still lands at `0.0990` BPB over three seeds, well ahead of the currently visible two-pass frontier. + +## Key Design Choices + +### Full-stream rescore + +Like BROADSIDE, this rescoring covers the full validation stream rather than only a fixed prefix. The gain is still mostly structural: + +- no second neural forward pass +- vectorized cache construction +- enough eval headroom to score all tokens rather than only the coldest chunks + +### Leave-one-out self-exclusion + +This is the main difference from the more aggressive companion submission. At score time, each token's own direct contribution is removed before eligibility and probability are computed. The cache stays global; the self-count does not. + +### N-gram parameters + +- order `2-12` +- `4,194,304` buckets +- alpha range `[0.05, 0.70]` +- entropy-adaptive alpha +- low orders suppressed, high orders boosted +- `min_count >= 2` + +### Complementary training + +Complementary training remains enabled, so the neural model is still encouraged to spend capacity on tokens the n-gram stack is less likely to predict well. 
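As a rough illustration of the entropy-adaptive alpha listed under the n-gram parameters, the sketch below maps per-token model entropy into the `[0.05, 0.70]` range and blends the two probabilities. The logistic form, the way `center` and `scale` enter it, and the names `entropy_adaptive_alpha` and `mix` are assumptions for the example; the actual formula in `train_gpt.py` may differ, and order multipliers are not modeled here.

```python
import numpy as np

def entropy_adaptive_alpha(entropy, alpha_min=0.05, alpha_max=0.70,
                           center=3.0, scale=2.0):
    """Map model entropy to a mixing weight (assumed logistic ramp)."""
    # Low entropy (confident model) -> alpha near alpha_min;
    # high entropy (uncertain model) -> alpha near alpha_max.
    gate = 1.0 / (1.0 + np.exp(-scale * (entropy - center)))
    return alpha_min + (alpha_max - alpha_min) * gate

def mix(model_p, ngram_p, eligible, alpha):
    """Blend neural and n-gram probabilities where the n-gram match is eligible."""
    mixed = (1.0 - alpha) * model_p + alpha * ngram_p
    # Tokens without an eligible n-gram match keep the neural probability.
    return np.where(eligible, mixed, model_p)

# Hypothetical per-token values, as they might be stored after Pass 1.
entropy = np.array([0.5, 3.0, 6.0])
model_p = np.array([0.9, 0.2, 0.01])
ngram_p = np.array([0.5, 0.8, 0.6])
eligible = np.array([True, True, False])

alpha = entropy_adaptive_alpha(entropy)
mixed_p = mix(model_p, ngram_p, eligible, alpha)
```

Under these assumptions a token whose entropy sits exactly at the center gets the midpoint weight of the alpha range, and the third token is left at its neural probability because no eligible n-gram match exists for it.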
+ +## Timing Budget (8xH100) + +| Phase | Time | +|-------|------| +| Training | 600s | +| Diagnostic eval | ~2s | +| GPTQ int6 export + roundtrip | ~7s | +| Sliding window eval | ~75s | +| N-gram Pass 1 | ~89s | +| Cache build | ~33s | +| N-gram Pass 2 | ~22s | +| **Total N-gram eval (Pass 1 + cache build + Pass 2)** | **~144-145s** | + +## Reproduction + +```bash +bash launch.sh base +``` + +Multi-seed package: + +```bash +bash launch_multiseed.sh +``` + +This uses `SEEDS=1337,42,2025` by default and produces: + +```text +logs/ppm_loo_seed1337.txt +logs/ppm_loo_seed42.txt +logs/ppm_loo_seed2025.txt +``` + +## Notes + +This submission is intended as the more conservative counterpart to the companion full-rescore result. It keeps the same decoupled full-rescore eval architecture, but removes each token's own direct cache contribution during rescoring. + +Co-authored with Codex. diff --git a/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/launch.sh b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/launch.sh new file mode 100755 index 0000000000..092c08debc --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/launch.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash +# Launch leave-one-out PPM N-gram Rescore follow-up +# Usage: bash launch.sh [base|smoke] +set -euo pipefail + +MODE="${1:-base}" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TRAIN_SCRIPT="$SCRIPT_DIR/train_gpt.py" + +# Shared defaults +export DATA_ROOT_MODE="${DATA_ROOT_MODE:-tmp}" +export COMPLEMENT_ENABLED="${COMPLEMENT_ENABLED:-1}" +export COMPLEMENT_ALPHA="${COMPLEMENT_ALPHA:-0.5}" +export NGRAM_ENABLED="${NGRAM_ENABLED:-1}" +export NGRAM_MIN_ORDER="${NGRAM_MIN_ORDER:-2}" +export NGRAM_MAX_ORDER="${NGRAM_MAX_ORDER:-$([[ "$MODE" == "smoke" ]] && echo 9 || echo 12)}" +export NGRAM_NUM_BUCKETS="${NGRAM_NUM_BUCKETS:-4194304}" +export NGRAM_CHUNK_SIZE="${NGRAM_CHUNK_SIZE:-512}" +export NGRAM_ALPHA_MIN="${NGRAM_ALPHA_MIN:-0.05}" +export NGRAM_ALPHA_MAX="${NGRAM_ALPHA_MAX:-0.70}" 
+export NGRAM_ENTROPY_CENTER="${NGRAM_ENTROPY_CENTER:-3.0}" +export NGRAM_ENTROPY_SCALE="${NGRAM_ENTROPY_SCALE:-2.0}" +export NGRAM_MIN_COUNT="${NGRAM_MIN_COUNT:-2}" +export NGRAM_LEAVE_ONE_OUT="${NGRAM_LEAVE_ONE_OUT:-1}" +export TTT_ENABLED="${TTT_ENABLED:-0}" +export EVAL_STRIDE="${EVAL_STRIDE:-64}" + +# Data paths +if [[ "${DATA_ROOT_MODE}" == "tmp" ]]; then + DATA_BASE="/tmp/parameter-golf-data" +else + DATA_BASE="/workspace/parameter-golf/data" +fi +export DATA_PATH="${DATA_PATH:-${DATA_BASE}/datasets/fineweb10B_sp1024}" +export TOKENIZER_PATH="${TOKENIZER_PATH:-${DATA_BASE}/tokenizers/fineweb_1024_bpe.model}" + +case "$MODE" in + smoke) + echo "=== SMOKE TEST (1xGPU, 180s, USE_COMPILE=0) ===" + export NPROC_PER_NODE="${NPROC_PER_NODE:-1}" + export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-180}" + export USE_COMPILE="${USE_COMPILE:-0}" + export NGRAM_MAX_ORDER="${NGRAM_MAX_ORDER:-9}" + export NGRAM_NUM_BUCKETS="${NGRAM_NUM_BUCKETS:-4194304}" + ;; + base) + echo "=== FULL RUN (8xGPU, 600s) ===" + export NPROC_PER_NODE="${NPROC_PER_NODE:-8}" + export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" + export USE_COMPILE="${USE_COMPILE:-1}" + ;; + *) + echo "Unknown mode: $MODE (use 'base' or 'smoke')" + exit 1 + ;; +esac + +# Verify data +if [[ -f "/workspace/parameter-golf/verify_runpod_data_ready.sh" ]]; then + bash /workspace/parameter-golf/verify_runpod_data_ready.sh "$DATA_PATH" "$TOKENIZER_PATH" +fi + +echo "Train script: $TRAIN_SCRIPT" +echo "Data path: $DATA_PATH" +echo "NGRAM: orders=${NGRAM_MIN_ORDER}-${NGRAM_MAX_ORDER} buckets=${NGRAM_NUM_BUCKETS} alpha=[${NGRAM_ALPHA_MIN},${NGRAM_ALPHA_MAX}] leave_one_out=${NGRAM_LEAVE_ONE_OUT}" +echo "COMPLEMENT: enabled=${COMPLEMENT_ENABLED} alpha=${COMPLEMENT_ALPHA}" + +NPROC="${NPROC_PER_NODE:-8}" +if [[ "$NPROC" -eq 1 ]]; then + python3 "$TRAIN_SCRIPT" +else + torchrun --standalone --nproc_per_node="$NPROC" "$TRAIN_SCRIPT" +fi diff --git 
a/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/launch_multiseed.sh b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/launch_multiseed.sh new file mode 100755 index 0000000000..c085d517f5 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/launch_multiseed.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# Launch the leave-one-out PPM candidate across the standard 3-seed package. +# Usage: +# bash launch_multiseed.sh # seeds 1337,42,2025 +# SEEDS=1337,42 bash launch_multiseed.sh +# MODE=smoke bash launch_multiseed.sh +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +MODE="${MODE:-base}" +SEEDS_CSV="${SEEDS:-1337,42,2025}" + +IFS=',' read -r -a SEEDS_ARR <<< "$SEEDS_CSV" + +echo "mode=$MODE" +echo "seeds=${SEEDS_CSV}" +echo "leave_one_out=${NGRAM_LEAVE_ONE_OUT:-1}" + +for seed in "${SEEDS_ARR[@]}"; do + seed="$(echo "$seed" | xargs)" + if [[ -z "$seed" ]]; then + continue + fi + export SEED="$seed" + export RUN_ID="ppm_loo_seed${seed}" + echo + echo "=== seed ${seed} ===" + bash "$SCRIPT_DIR/launch.sh" "$MODE" +done diff --git a/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/submission.json b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/submission.json new file mode 100644 index 0000000000..8ed7a73308 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/submission.json @@ -0,0 +1,24 @@ +{ + "author": "Simon Marcus", + "github_id": "simon-marcus", + "name": "WaterLOO: Full-Rescore N-gram Cache with Self-Exclusion", + "blurb": "Two-pass full-rescore n-gram eval with leave-one-out self-exclusion. 
Pass 1 stores per-token neural probabilities and entropies, the complete order-2-12 cache is built vectorially, and Pass 2 rescoring subtracts each token's own direct cache contribution before matching.", + "date": "2026-03-26", + "val_loss": 0.16713198, + "val_bpb": 0.09898524, + "val_loss_std": 0.00004, + "val_bpb_std": 0.00002, + "seeds": [1337, 42, 2025], + "seed_results": { + "1337": {"val_loss": 0.16710306, "val_bpb": 0.09896811}, + "42": {"val_loss": 0.16710815, "val_bpb": 0.09897112}, + "2025": {"val_loss": 0.16718473, "val_bpb": 0.09901648} + }, + "pre_quant_val_bpb": 1.14047, + "step_stop": 6931, + "wallclock_seconds": 600.0, + "eval_time_seconds": 144.77, + "bytes_total": 15873808, + "bytes_code": 115396, + "bytes_model": 15758412 +} diff --git a/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_gpt.py b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_gpt.py new file mode 100644 index 0000000000..6b4d4e3858 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_gpt.py @@ -0,0 +1,2442 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = 
os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_tokens_limit = int(os.environ.get("VAL_TOKENS_LIMIT", 0)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = 
float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + activation_mode = os.environ.get("ACTIVATION_MODE", "leaky_relu_sq") + activation_neg_slope = float(os.environ.get("ACTIVATION_NEG_SLOPE", 0.5)) + asymmetric_square_init = float(os.environ.get("ASYMMETRIC_SQUARE_INIT", 0.25)) + gated_square_beta_init = 
float(os.environ.get("GATED_SQUARE_BETA_INIT", 1.0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + # N-gram eval cache + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1"))) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", 2)) + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", 12)) + ngram_num_buckets = int(os.environ.get("NGRAM_NUM_BUCKETS", 16_777_216)) # 16M + ngram_chunk_size = int(os.environ.get("NGRAM_CHUNK_SIZE", 512)) + ngram_alpha_min = float(os.environ.get("NGRAM_ALPHA_MIN", 0.05)) + ngram_alpha_max = float(os.environ.get("NGRAM_ALPHA_MAX", 0.70)) + ngram_entropy_center = float(os.environ.get("NGRAM_ENTROPY_CENTER", 3.0)) + ngram_entropy_scale = float(os.environ.get("NGRAM_ENTROPY_SCALE", 2.0)) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", 2)) + ngram_leave_one_out = bool(int(os.environ.get("NGRAM_LEAVE_ONE_OUT", "1"))) + # Complementary training + complement_enabled = bool(int(os.environ.get("COMPLEMENT_ENABLED", "0"))) + complement_alpha = float(os.environ.get("COMPLEMENT_ALPHA", 0.5)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. 
G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, 
dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if 
nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int, token_limit: int = 0) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: 
{pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + if token_limit > 0: + tokens = tokens[: min(tokens.numel(), token_limit + 1)] + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = 
model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, 
torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] 
+= tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, 
avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = 
param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def 
__init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + # No CastedLinear -- weights come from banks + self.activation_mode = activation_mode + self.activation_neg_slope = activation_neg_slope + if activation_mode == "asymmetric_square": + self.neg_sq_scale = nn.Parameter(torch.tensor(asymmetric_square_init, dtype=torch.float32)) + else: + self.neg_sq_scale = None + if activation_mode == "gated_square": + self.gated_square_beta = nn.Parameter(torch.tensor(gated_square_beta_init, dtype=torch.float32)) + else: + self.gated_square_beta = None + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + u = F.linear(x, up_w.to(x.dtype)) + if self.activation_mode == "leaky_relu_sq": + h = F.leaky_relu(u, negative_slope=self.activation_neg_slope).square() + elif self.activation_mode == "asymmetric_square": + neg_sq_scale = self.neg_sq_scale.to(dtype=u.dtype).clamp(0.0, 4.0) + h = F.relu(u).square() + neg_sq_scale * F.relu(-u).square() + elif self.activation_mode == "gated_square": + beta = self.gated_square_beta.to(dtype=u.dtype).clamp(0.0, 8.0) + h = u.square() * torch.sigmoid(beta * u) + elif 
self.activation_mode == "sign_preserving_square": + h = u * u.abs() + else: + raise ValueError(f"Unknown ACTIVATION_MODE={self.activation_mode}") + return F.linear(h, down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP( + dim, + mlp_mult, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, 
v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = 
min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not 
None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else 
self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = 
target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: 
Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += 
(has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, 
dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), 
reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / 
math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# === N-GRAM EVAL CACHE + TWO-PASS RESCORE === + +_NGRAM_PRIMES = np.array([ + 36313, 27191, 51647, 81929, 131071, 174763, 233017, 283721, + 347237, 411527, 479909, 557927, 646333, 746773, 862319, 992353, +], dtype=np.int64) + +# Per-order multipliers: orders 2-3 suppressed, 4 near-neutral, 5-12 boosted +_ORDER_MULTS = np.array([ + 0.30, 0.30, 0.97, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, +], dtype=np.float32) + + +class NgramCache: + """Hash-table n-gram cache with vectorized numpy operations.""" + + def __init__(self, min_order: int = 2, max_order: int = 16, + num_buckets: int = 16_777_216): + self.min_order = min_order + self.max_order = max_order + self.num_orders = max_order - min_order + 1 + self.num_buckets = num_buckets + self.bucket_mask = np.int64(num_buckets - 1) + # Two flat hash tables per order: context counts and full (context+target) counts + self.ctx_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + self.full_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + + def _compute_hashes(self, tokens_np: np.ndarray, start: int, end: int, order_idx: int): + """Compute context and full hashes for positions [start, end) at given order.""" + n = 
self.min_order + order_idx + valid_start = max(start, n - 1) + N = end - valid_start + if N <= 0: + return None, None, valid_start + # Context hash: XOR of tokens[pos-n+1+k] * primes[k] for k=0..n-2 + h = np.zeros(N, dtype=np.int64) + for k in range(n - 1): + offset = valid_start - (n - 1) + k + h ^= tokens_np[offset:offset + N].astype(np.int64) * _NGRAM_PRIMES[k % len(_NGRAM_PRIMES)] + ctx_h = h & self.bucket_mask + # Full hash: context + target token + target_prime = _NGRAM_PRIMES[min(n - 1, len(_NGRAM_PRIMES) - 1)] + full_h = (h ^ (tokens_np[valid_start:end].astype(np.int64) * target_prime)) & self.bucket_mask + return ctx_h, full_h, valid_start + + def _bincount_add(self, table: np.ndarray, indices: np.ndarray): + """Fast histogram accumulation using np.bincount (much faster than np.add.at).""" + counts = np.bincount(indices.astype(np.intp), minlength=self.num_buckets) + table += counts[:self.num_buckets].astype(table.dtype) + + def update_range(self, tokens_np: np.ndarray, start: int, end: int): + """Add tokens[start:end] to the cache for all orders.""" + for oi in range(self.num_orders): + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def build_full(self, tokens_np: np.ndarray): + """Build complete cache from entire token sequence (vectorized).""" + for oi in range(self.num_orders): + ctx_h, full_h, _ = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def score_range(self, tokens_np: np.ndarray, start: int, end: int, + min_count: int = 2): + """Score tokens[start:end] against the cache. 
+ + Returns: + ngram_prob: (N,) float32 - n-gram probability for the true target token + matched_order: (N,) int32 - which order matched (-1 = no match) + """ + N = end - start + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + + # Backoff from highest to lowest order + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + offset = vs - start + ctx_counts = self.ctx_tables[oi][ctx_h] + full_counts = self.full_tables[oi][full_h] + # Cap full counts to context counts (hash collision mitigation) + full_counts = np.minimum(full_counts, ctx_counts) + # Only match when: sufficient context, target has been seen, not already matched + eligible = (ctx_counts >= min_count) & (full_counts > 0) & ~matched[offset:] + if not np.any(eligible): + continue + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + # Find which positions in the output array to fill + out_idx = np.where(eligible)[0] + offset + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + def score_positions(self, tokens_np: np.ndarray, positions: np.ndarray, + min_count: int = 2, leave_one_out: bool = False): + """Score selected token positions against the cache. + + If leave_one_out is enabled, subtract this token's own contribution from + both context and (context,target) counts before matching. 
+ """ + N = len(positions) + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + if N == 0: + return ngram_prob, matched_order + + positions = positions.astype(np.int64, copy=False) + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h_all, full_h_all, valid_start = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h_all is None: + continue + + remaining_idx = np.where(~matched)[0] + if remaining_idx.size == 0: + break + pos_sub = positions[remaining_idx] + valid_mask = pos_sub >= valid_start + if not np.any(valid_mask): + continue + + valid_idx = remaining_idx[valid_mask] + lookup = (pos_sub[valid_mask] - valid_start).astype(np.int64) + ctx_h = ctx_h_all[lookup] + full_h = full_h_all[lookup] + + ctx_counts = self.ctx_tables[oi][ctx_h].astype(np.int64) + full_counts = self.full_tables[oi][full_h].astype(np.int64) + if leave_one_out: + ctx_counts = np.maximum(ctx_counts - 1, 0) + full_counts = np.maximum(full_counts - 1, 0) + full_counts = np.minimum(full_counts, ctx_counts) + + eligible = (ctx_counts >= min_count) & (full_counts > 0) + if not np.any(eligible): + continue + + out_idx = valid_idx[eligible] + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + +def eval_val_sliding_store( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, float]: + """Sliding-window eval that stores per-token model_p and entropy. 
+ + Returns: (model_p, entropy, token_bytes, positions, val_loss, val_bpb) + where model_p, entropy, and token_bytes are per-token arrays covering this rank's scored tokens, + positions holds each scored token's global target index in val_tokens, + and val_loss/val_bpb are the standard (un-blended) metrics. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + # Per-token storage: one array per window, concatenated after the loop + # Each token is scored in exactly one window + model_p_list: list[np.ndarray] = [] + entropy_list: list[np.ndarray] = [] + bytes_list: list[np.ndarray] = [] + position_list: list[np.ndarray] = [] # global target-token positions + nll_list: list[np.ndarray] = [] + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end_pos = min(ws + seq_len, total_tokens) + wlen = end_pos - ws + wlens.append(wlen) + chunk = val_tokens[ws:end_pos + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) # (bsz, seq_len, vocab_size) + # Compute per-token quantities + logits_f = logits.float() + log_probs = F.log_softmax(logits_f, dim=-1) # (bsz, seq_len, V) + probs = log_probs.exp() + # NLL for each token + nll_all = 
F.cross_entropy( + logits_f.reshape(-1, logits_f.size(-1)), + y_batch.reshape(-1), reduction="none" + ).reshape(bsz, seq_len) + # Model probability of true token + mp = probs.gather(2, y_batch.unsqueeze(-1)).squeeze(-1) # (bsz, seq_len) + # Entropy of model distribution + ent = -(probs * log_probs).sum(dim=-1) # (bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + # Positions are TARGET token indices in val_tokens (ws+j+1 for scored position j) + positions = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + position_list.append(positions) + model_p_list.append(mp[i, s:wlen].cpu().numpy().astype(np.float32)) + entropy_list.append(ent[i, s:wlen].cpu().numpy().astype(np.float32)) + nll_list.append(nll_all[i, s:wlen].cpu().numpy().astype(np.float64)) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bytes_list.append(tb.cpu().numpy()) + + all_positions = np.concatenate(position_list) if position_list else np.array([], dtype=np.int64) + all_model_p = np.concatenate(model_p_list) if model_p_list else np.array([], dtype=np.float32) + all_entropy = np.concatenate(entropy_list) if entropy_list else np.array([], dtype=np.float32) + all_nll = np.concatenate(nll_list) if nll_list else np.array([], dtype=np.float64) + all_bytes = np.concatenate(bytes_list) if bytes_list else np.array([], dtype=np.float64) + + + # Compute standard (un-blended) BPB for this rank + local_loss_sum = all_nll.sum() + local_token_count = float(len(all_nll)) + local_byte_count = all_bytes.sum() + + # All-reduce for standard BPB + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() 
and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + base_model.train() + return all_model_p, all_entropy, all_bytes, all_positions, val_loss, val_bpb + + +def ngram_rescore( + args: Hyperparameters, + tokens_np: np.ndarray, + cache: NgramCache, + model_p: np.ndarray, + entropy: np.ndarray, + token_bytes: np.ndarray, + positions: np.ndarray, + rank: int, world_size: int, device: torch.device, + log0=print, +) -> tuple[float, float]: + """Rescore tokens using n-gram cache blended with stored neural model_p. + + This is Pass 2: the cache is already complete. + Returns: (val_loss, val_bpb) + """ + N = len(positions) + if N == 0: + return 0.0, 0.0 + + # Score this rank's positions against the full cache. + # positions holds GLOBAL target-token indices and need not be contiguous, + # so use score_positions (per-position lookup) rather than score_range. + + ngram_prob, matched_order = cache.score_positions( + tokens_np, + positions, + min_count=args.ngram_min_count, + leave_one_out=args.ngram_leave_one_out, + ) + matched = matched_order >= 0 + + # Entropy-adaptive alpha with per-order multipliers + alpha = np.zeros(N, dtype=np.float32) + if np.any(matched): + order_idx = (matched_order[matched] - cache.min_order).astype(np.int32) + centers = args.ngram_entropy_center - 0.25 * order_idx.astype(np.float32) + sig = 1.0 / (1.0 + np.exp(-args.ngram_entropy_scale * (entropy[matched] - centers))) + raw_alpha = args.ngram_alpha_min + (args.ngram_alpha_max - args.ngram_alpha_min) * sig + # Per-order multipliers + mults = _ORDER_MULTS[np.minimum(order_idx, len(_ORDER_MULTS) - 1)] + 
raw_alpha *= mults + alpha[matched] = np.clip(raw_alpha, 0.0, 0.95) + + # Blend: p_blend = (1 - alpha) * model_p + alpha * ngram_prob + p_blend = (1.0 - alpha) * model_p + alpha * ngram_prob + # Clamp to avoid log(0) + p_blend = np.maximum(p_blend, 1e-10) + # For unmatched tokens, use model_p directly + p_blend[~matched] = np.maximum(model_p[~matched], 1e-10) + + # NLL + nll = -np.log(p_blend).astype(np.float64) + + # Aggregate + local_loss_sum = nll.sum() + local_token_count = float(N) + local_byte_count = token_bytes.sum() + + # All-reduce + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + n_matched = int(matched.sum()) + log0( + f"ngram_rescore: matched={n_matched}/{N} ({100*n_matched/max(N,1):.1f}%) " + f"mean_alpha={alpha[matched].mean():.3f} leave_one_out={int(args.ngram_leave_one_out)}" + if n_matched > 0 else f"ngram_rescore: no matches leave_one_out={int(args.ngram_leave_one_out)}" + ) + + return val_loss, val_bpb + + +def eval_ngram_two_pass( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Two-pass n-gram evaluation. + + Pass 1: Sliding-window neural eval → store per-token model_p and entropy. + Build: Complete n-gram cache from all tokens (vectorized). 
+ Pass 2: Rescore ALL tokens by blending neural model_p with n-gram predictions. + """ + t0 = time.perf_counter() + + # --- Pass 1: Neural eval with per-token storage --- + log0(f"ngram_two_pass: starting Pass 1 (sliding-window neural eval)") + model_p, entropy, token_bytes, positions, pass1_loss, pass1_bpb = eval_val_sliding_store( + args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=stride, batch_seqs=batch_seqs, log0=log0, + ) + t_pass1 = time.perf_counter() + log0(f"ngram_two_pass: Pass 1 done val_bpb={pass1_bpb:.6f} " + f"tokens_scored={len(positions)} time={t_pass1 - t0:.1f}s") + + # --- Build complete n-gram cache --- + log0(f"ngram_two_pass: building cache orders={args.ngram_min_order}-{args.ngram_max_order} " + f"buckets={args.ngram_num_buckets}") + tokens_np = val_tokens.numpy().astype(np.int16) + cache = NgramCache( + min_order=args.ngram_min_order, + max_order=args.ngram_max_order, + num_buckets=args.ngram_num_buckets, + ) + cache.build_full(tokens_np) + t_cache = time.perf_counter() + log0(f"ngram_two_pass: cache built in {t_cache - t_pass1:.1f}s") + + # --- Pass 2: N-gram rescore --- + log0(f"ngram_two_pass: starting Pass 2 (n-gram rescore)") + val_loss, val_bpb = ngram_rescore( + args, tokens_np, cache, model_p, entropy, token_bytes, positions, + rank, world_size, device, log0=log0, + ) + t_pass2 = time.perf_counter() + log0(f"ngram_two_pass: Pass 2 done val_bpb={val_bpb:.6f} " + f"improvement={pass1_bpb - val_bpb:.6f} time={t_pass2 - t_cache:.1f}s") + log0(f"ngram_two_pass: total time={t_pass2 - t0:.1f}s") + + return val_loss, val_bpb + + +# === COMPLEMENTARY TRAINING === + +class TrainBigramTracker: + """Tracks bigram statistics from training data for complementary loss weighting.""" + + def __init__(self, vocab_size: int, device: torch.device): + # bigram_counts[prev_token, target_token] = count + self.counts = torch.zeros(vocab_size, vocab_size, device=device, 
dtype=torch.float32) + self.row_totals = torch.zeros(vocab_size, device=device, dtype=torch.float32) + + @torch.no_grad() + def update(self, x: Tensor, y: Tensor): + """Update bigram counts. x: context tokens, y: target tokens.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + idx = prev.long() * self.counts.shape[1] + tgt.long() + self.counts.view(-1).scatter_add_(0, idx, torch.ones_like(idx, dtype=torch.float32)) + self.row_totals.scatter_add_(0, prev.long(), torch.ones(prev.shape[0], device=prev.device, dtype=torch.float32)) + + @torch.no_grad() + def get_weights(self, x: Tensor, y: Tensor, alpha: float = 0.5) -> Tensor: + """Compute per-token loss weights: downweight tokens predictable by bigrams.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + totals = self.row_totals[prev.long()] + counts = self.counts[prev.long(), tgt.long()] + ngram_prob = counts / totals.clamp(min=1.0) + weights = (1.0 - alpha * ngram_prob).clamp(min=0.1) + return weights.reshape(y.shape) + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks 
from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = 
quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + 
torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len, args.val_tokens_limit) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + 
log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + activation_mode=args.activation_mode, + activation_neg_slope=args.activation_neg_slope, + asymmetric_square_init=args.asymmetric_square_init, + gated_square_beta_init=args.gated_square_beta_init, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. 
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + # Separate compile for forward_logits (used in complementary training) + compiled_forward_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = 
Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} 
" + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"activation_mode:{args.activation_mode} neg_slope:{args.activation_neg_slope} " + f"asym_init:{args.asymmetric_square_init} gated_beta_init:{args.gated_square_beta_init}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if 
distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + # Complementary training tracker + bigram_tracker = TrainBigramTracker(args.vocab_size, device) if args.complement_enabled else None + if bigram_tracker is not None: + log0(f"complement:enabled alpha={args.complement_alpha}") + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + 
if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if args.complement_enabled and bigram_tracker is not None: + # Complementary training: single forward, weighted CE + logits = compiled_forward_logits(x) + logits_flat = logits.reshape(-1, logits.size(-1)).float() + per_token_nll = F.cross_entropy(logits_flat, y.reshape(-1), reduction="none") + comp_weights = bigram_tracker.get_weights(x, y, alpha=args.complement_alpha).reshape(-1) + loss = (per_token_nll * comp_weights).sum() / comp_weights.sum() + bigram_tracker.update(x, y) + else: + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + 
optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory 
allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: 
v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + activation_mode=args.activation_mode, + activation_neg_slope=args.activation_neg_slope, + asymmetric_square_init=args.asymmetric_square_init, + 
gated_square_beta_init=args.gated_square_beta_init, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + 
torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + # --- N-gram two-pass rescore --- + if args.ngram_enabled: + # Use TTT-adapted model if available, otherwise use quantized eval model + ngram_model = eval_model + torch.cuda.synchronize() + t_ngram = time.perf_counter() + ng_val_loss, ng_val_bpb = eval_ngram_two_pass( + args, ngram_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"ngram_two_pass val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms") + log0(f"ngram_two_pass_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed1337.log b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed1337.log new file mode 100644 index 0000000000..2e8e4a0396 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed1337.log @@ -0,0 +1,2582 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_tokens_limit = int(os.environ.get("VAL_TOKENS_LIMIT", 0)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + 
train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = 
bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + activation_mode = os.environ.get("ACTIVATION_MODE", "leaky_relu_sq") + activation_neg_slope = float(os.environ.get("ACTIVATION_NEG_SLOPE", 0.5)) + asymmetric_square_init = float(os.environ.get("ASYMMETRIC_SQUARE_INIT", 0.25)) + gated_square_beta_init = float(os.environ.get("GATED_SQUARE_BETA_INIT", 1.0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + # N-gram eval cache + ngram_enabled = 
bool(int(os.environ.get("NGRAM_ENABLED", "1"))) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", 2)) + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", 12)) + ngram_num_buckets = int(os.environ.get("NGRAM_NUM_BUCKETS", 16_777_216)) # 16M + ngram_chunk_size = int(os.environ.get("NGRAM_CHUNK_SIZE", 512)) + ngram_alpha_min = float(os.environ.get("NGRAM_ALPHA_MIN", 0.05)) + ngram_alpha_max = float(os.environ.get("NGRAM_ALPHA_MAX", 0.70)) + ngram_entropy_center = float(os.environ.get("NGRAM_ENTROPY_CENTER", 3.0)) + ngram_entropy_scale = float(os.environ.get("NGRAM_ENTROPY_SCALE", 2.0)) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", 2)) + ngram_leave_one_out = bool(int(os.environ.get("NGRAM_LEAVE_ONE_OUT", "1"))) + # Complementary training + complement_enabled = bool(int(os.environ.get("COMPLEMENT_ENABLED", "0"))) + complement_alpha = float(os.environ.get("COMPLEMENT_ALPHA", 0.5)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int, token_limit: int = 0) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + if token_limit > 0: + tokens = tokens[: min(tokens.numel(), token_limit + 1)] + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for 
TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += 
(has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), 
INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += 
tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + 
self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + 
inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if 
num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + # No CastedLinear -- weights come from banks + self.activation_mode = activation_mode + self.activation_neg_slope = activation_neg_slope + if activation_mode == "asymmetric_square": + self.neg_sq_scale = nn.Parameter(torch.tensor(asymmetric_square_init, dtype=torch.float32)) + else: + self.neg_sq_scale = None + if activation_mode == "gated_square": + self.gated_square_beta = nn.Parameter(torch.tensor(gated_square_beta_init, dtype=torch.float32)) + else: + self.gated_square_beta = None + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + u = F.linear(x, up_w.to(x.dtype)) + if self.activation_mode == "leaky_relu_sq": + h = F.leaky_relu(u, negative_slope=self.activation_neg_slope).square() + elif self.activation_mode == "asymmetric_square": + neg_sq_scale = self.neg_sq_scale.to(dtype=u.dtype).clamp(0.0, 4.0) + h = F.relu(u).square() + neg_sq_scale * F.relu(-u).square() + elif self.activation_mode == "gated_square": + beta = self.gated_square_beta.to(dtype=u.dtype).clamp(0.0, 8.0) + h = u.square() * torch.sigmoid(beta * u) + elif 
self.activation_mode == "sign_preserving_square": + h = u * u.abs() + else: + raise ValueError(f"Unknown ACTIVATION_MODE={self.activation_mode}") + return F.linear(h, down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP( + dim, + mlp_mult, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, 
v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = 
min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not 
None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else 
self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = 
target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: 
Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += 
(has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, 
dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), 
reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / 
math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# === N-GRAM EVAL CACHE + TWO-PASS RESCORE === + +_NGRAM_PRIMES = np.array([ + 36313, 27191, 51647, 81929, 131071, 174763, 233017, 283721, + 347237, 411527, 479909, 557927, 646333, 746773, 862319, 992353, +], dtype=np.int64) + +# Per-order multipliers: orders 2-3 suppressed, 4 near-neutral, 5-12 boosted +_ORDER_MULTS = np.array([ + 0.30, 0.30, 0.97, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, +], dtype=np.float32) + + +class NgramCache: + """Hash-table n-gram cache with vectorized numpy operations.""" + + def __init__(self, min_order: int = 2, max_order: int = 16, + num_buckets: int = 16_777_216): + self.min_order = min_order + self.max_order = max_order + self.num_orders = max_order - min_order + 1 + self.num_buckets = num_buckets + self.bucket_mask = np.int64(num_buckets - 1) + # Two flat hash tables per order: context counts and full (context+target) counts + self.ctx_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + self.full_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + + def _compute_hashes(self, tokens_np: np.ndarray, start: int, end: int, order_idx: int): + """Compute context and full hashes for positions [start, end) at given order.""" + n = 
self.min_order + order_idx + valid_start = max(start, n - 1) + N = end - valid_start + if N <= 0: + return None, None, valid_start + # Context hash: XOR of tokens[pos-n+1+k] * primes[k] for k=0..n-2 + h = np.zeros(N, dtype=np.int64) + for k in range(n - 1): + offset = valid_start - (n - 1) + k + h ^= tokens_np[offset:offset + N].astype(np.int64) * _NGRAM_PRIMES[k % len(_NGRAM_PRIMES)] + ctx_h = h & self.bucket_mask + # Full hash: context + target token + target_prime = _NGRAM_PRIMES[min(n - 1, len(_NGRAM_PRIMES) - 1)] + full_h = (h ^ (tokens_np[valid_start:end].astype(np.int64) * target_prime)) & self.bucket_mask + return ctx_h, full_h, valid_start + + def _bincount_add(self, table: np.ndarray, indices: np.ndarray): + """Fast histogram accumulation using np.bincount (much faster than np.add.at).""" + counts = np.bincount(indices.astype(np.intp), minlength=self.num_buckets) + table += counts[:self.num_buckets].astype(table.dtype) + + def update_range(self, tokens_np: np.ndarray, start: int, end: int): + """Add tokens[start:end] to the cache for all orders.""" + for oi in range(self.num_orders): + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def build_full(self, tokens_np: np.ndarray): + """Build complete cache from entire token sequence (vectorized).""" + for oi in range(self.num_orders): + ctx_h, full_h, _ = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def score_range(self, tokens_np: np.ndarray, start: int, end: int, + min_count: int = 2): + """Score tokens[start:end] against the cache. 
+ + Returns: + ngram_prob: (N,) float32 - n-gram probability for the true target token + matched_order: (N,) int32 - which order matched (-1 = no match) + """ + N = end - start + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + + # Backoff from highest to lowest order + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + offset = vs - start + ctx_counts = self.ctx_tables[oi][ctx_h] + full_counts = self.full_tables[oi][full_h] + # Cap full counts to context counts (hash collision mitigation) + full_counts = np.minimum(full_counts, ctx_counts) + # Only match when: sufficient context, target has been seen, not already matched + eligible = (ctx_counts >= min_count) & (full_counts > 0) & ~matched[offset:] + if not np.any(eligible): + continue + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + # Find which positions in the output array to fill + out_idx = np.where(eligible)[0] + offset + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + def score_positions(self, tokens_np: np.ndarray, positions: np.ndarray, + min_count: int = 2, leave_one_out: bool = False): + """Score selected token positions against the cache. + + If leave_one_out is enabled, subtract this token's own contribution from + both context and (context,target) counts before matching. 
+ """ + N = len(positions) + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + if N == 0: + return ngram_prob, matched_order + + positions = positions.astype(np.int64, copy=False) + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h_all, full_h_all, valid_start = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h_all is None: + continue + + remaining_idx = np.where(~matched)[0] + if remaining_idx.size == 0: + break + pos_sub = positions[remaining_idx] + valid_mask = pos_sub >= valid_start + if not np.any(valid_mask): + continue + + valid_idx = remaining_idx[valid_mask] + lookup = (pos_sub[valid_mask] - valid_start).astype(np.int64) + ctx_h = ctx_h_all[lookup] + full_h = full_h_all[lookup] + + ctx_counts = self.ctx_tables[oi][ctx_h].astype(np.int64) + full_counts = self.full_tables[oi][full_h].astype(np.int64) + if leave_one_out: + ctx_counts = np.maximum(ctx_counts - 1, 0) + full_counts = np.maximum(full_counts - 1, 0) + full_counts = np.minimum(full_counts, ctx_counts) + + eligible = (ctx_counts >= min_count) & (full_counts > 0) + if not np.any(eligible): + continue + + out_idx = valid_idx[eligible] + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + +def eval_val_sliding_store( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, float]: + """Sliding-window eval that stores per-token model_p and entropy. 
+ + Returns: (model_p, entropy, token_bytes, positions, val_loss, val_bpb) + where model_p and entropy are arrays covering this rank's scored tokens, + and val_loss/val_bpb are the standard (un-blended) metrics. + + positions holds the global target-token index in val_tokens for each scored token. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + # Per-token storage, collected per window and concatenated at the end + # Each token is scored in exactly one window + model_p_list: list[np.ndarray] = [] + entropy_list: list[np.ndarray] = [] + bytes_list: list[np.ndarray] = [] + position_list: list[np.ndarray] = [] # global target-token positions + nll_list: list[np.ndarray] = [] + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end_pos = min(ws + seq_len, total_tokens) + wlen = end_pos - ws + wlens.append(wlen) + chunk = val_tokens[ws:end_pos + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) # (bsz, seq_len, vocab_size) + # Compute per-token quantities + logits_f = logits.float() + log_probs = F.log_softmax(logits_f, dim=-1) # (bsz, seq_len, V) + probs = log_probs.exp() + # NLL for each token + nll_all = 
F.cross_entropy( + logits_f.reshape(-1, logits_f.size(-1)), + y_batch.reshape(-1), reduction="none" + ).reshape(bsz, seq_len) + # Model probability of true token + mp = probs.gather(2, y_batch.unsqueeze(-1)).squeeze(-1) # (bsz, seq_len) + # Entropy of model distribution + ent = -(probs * log_probs).sum(dim=-1) # (bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + # Positions are TARGET token indices in val_tokens (ws+j+1 for scored position j) + positions = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + position_list.append(positions) + model_p_list.append(mp[i, s:wlen].cpu().numpy().astype(np.float32)) + entropy_list.append(ent[i, s:wlen].cpu().numpy().astype(np.float32)) + nll_list.append(nll_all[i, s:wlen].cpu().numpy().astype(np.float64)) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bytes_list.append(tb.cpu().numpy()) + + all_positions = np.concatenate(position_list) if position_list else np.array([], dtype=np.int64) + all_model_p = np.concatenate(model_p_list) if model_p_list else np.array([], dtype=np.float32) + all_entropy = np.concatenate(entropy_list) if entropy_list else np.array([], dtype=np.float32) + all_nll = np.concatenate(nll_list) if nll_list else np.array([], dtype=np.float64) + all_bytes = np.concatenate(bytes_list) if bytes_list else np.array([], dtype=np.float64) + + + # Compute standard (un-blended) BPB for this rank + local_loss_sum = all_nll.sum() + local_token_count = float(len(all_nll)) + local_byte_count = all_bytes.sum() + + # All-reduce for standard BPB + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() 
and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + base_model.train() + return all_model_p, all_entropy, all_bytes, all_positions, val_loss, val_bpb + + +def ngram_rescore( + args: Hyperparameters, + tokens_np: np.ndarray, + cache: NgramCache, + model_p: np.ndarray, + entropy: np.ndarray, + token_bytes: np.ndarray, + positions: np.ndarray, + rank: int, world_size: int, device: torch.device, + log0=print, +) -> tuple[float, float]: + """Rescore tokens using n-gram cache blended with stored neural model_p. + + This is Pass 2: the cache is already complete. + Returns: (val_loss, val_bpb) + """ + N = len(positions) + if N == 0: + return 0.0, 0.0 + + # Score all of this rank's positions against the full cache. + # Positions are GLOBAL target-token indices into tokens_np. + # cache.score_range expects a contiguous range, but these positions need not be + # contiguous, so score_positions hashes the full stream once per order and + # indexes into the result at each requested position. + + ngram_prob, matched_order = cache.score_positions( + tokens_np, + positions, + min_count=args.ngram_min_count, + leave_one_out=args.ngram_leave_one_out, + ) + matched = matched_order >= 0 + + # Entropy-adaptive alpha with per-order multipliers + alpha = np.zeros(N, dtype=np.float32) + if np.any(matched): + order_idx = (matched_order[matched] - cache.min_order).astype(np.int32) + centers = args.ngram_entropy_center - 0.25 * order_idx.astype(np.float32) + sig = 1.0 / (1.0 + np.exp(-args.ngram_entropy_scale * (entropy[matched] - centers))) + raw_alpha = args.ngram_alpha_min + (args.ngram_alpha_max - args.ngram_alpha_min) * sig + # Per-order multipliers + mults = _ORDER_MULTS[np.minimum(order_idx, len(_ORDER_MULTS) - 1)] + 
raw_alpha *= mults + alpha[matched] = np.clip(raw_alpha, 0.0, 0.95) + + # Blend: p_blend = (1 - alpha) * model_p + alpha * ngram_prob + p_blend = (1.0 - alpha) * model_p + alpha * ngram_prob + # Clamp to avoid log(0) + p_blend = np.maximum(p_blend, 1e-10) + # For unmatched tokens, use model_p directly + p_blend[~matched] = np.maximum(model_p[~matched], 1e-10) + + # NLL + nll = -np.log(p_blend).astype(np.float64) + + # Aggregate + local_loss_sum = nll.sum() + local_token_count = float(N) + local_byte_count = token_bytes.sum() + + # All-reduce + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + n_matched = int(matched.sum()) + log0( + f"ngram_rescore: matched={n_matched}/{N} ({100*n_matched/max(N,1):.1f}%) " + f"mean_alpha={alpha[matched].mean():.3f} leave_one_out={int(args.ngram_leave_one_out)}" + if n_matched > 0 else f"ngram_rescore: no matches leave_one_out={int(args.ngram_leave_one_out)}" + ) + + return val_loss, val_bpb + + +def eval_ngram_two_pass( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Two-pass n-gram evaluation. + + Pass 1: Sliding-window neural eval → store per-token model_p and entropy. + Build: Complete n-gram cache from all tokens (vectorized). 
+ Pass 2: Rescore ALL tokens by blending neural model_p with n-gram predictions. + """ + t0 = time.perf_counter() + + # --- Pass 1: Neural eval with per-token storage --- + log0(f"ngram_two_pass: starting Pass 1 (sliding-window neural eval)") + model_p, entropy, token_bytes, positions, pass1_loss, pass1_bpb = eval_val_sliding_store( + args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=stride, batch_seqs=batch_seqs, log0=log0, + ) + t_pass1 = time.perf_counter() + log0(f"ngram_two_pass: Pass 1 done val_bpb={pass1_bpb:.6f} " + f"tokens_scored={len(positions)} time={t_pass1 - t0:.1f}s") + + # --- Build complete n-gram cache --- + log0(f"ngram_two_pass: building cache orders={args.ngram_min_order}-{args.ngram_max_order} " + f"buckets={args.ngram_num_buckets}") + tokens_np = val_tokens.numpy().astype(np.int16) + cache = NgramCache( + min_order=args.ngram_min_order, + max_order=args.ngram_max_order, + num_buckets=args.ngram_num_buckets, + ) + cache.build_full(tokens_np) + t_cache = time.perf_counter() + log0(f"ngram_two_pass: cache built in {t_cache - t_pass1:.1f}s") + + # --- Pass 2: N-gram rescore --- + log0(f"ngram_two_pass: starting Pass 2 (n-gram rescore)") + val_loss, val_bpb = ngram_rescore( + args, tokens_np, cache, model_p, entropy, token_bytes, positions, + rank, world_size, device, log0=log0, + ) + t_pass2 = time.perf_counter() + log0(f"ngram_two_pass: Pass 2 done val_bpb={val_bpb:.6f} " + f"improvement={pass1_bpb - val_bpb:.6f} time={t_pass2 - t_cache:.1f}s") + log0(f"ngram_two_pass: total time={t_pass2 - t0:.1f}s") + + return val_loss, val_bpb + + +# === COMPLEMENTARY TRAINING === + +class TrainBigramTracker: + """Tracks bigram statistics from training data for complementary loss weighting.""" + + def __init__(self, vocab_size: int, device: torch.device): + # bigram_counts[prev_token, target_token] = count + self.counts = torch.zeros(vocab_size, vocab_size, device=device, 
dtype=torch.float32) + self.row_totals = torch.zeros(vocab_size, device=device, dtype=torch.float32) + + @torch.no_grad() + def update(self, x: Tensor, y: Tensor): + """Update bigram counts. x: context tokens, y: target tokens.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + idx = prev.long() * self.counts.shape[1] + tgt.long() + self.counts.view(-1).scatter_add_(0, idx, torch.ones_like(idx, dtype=torch.float32)) + self.row_totals.scatter_add_(0, prev.long(), torch.ones(prev.shape[0], device=prev.device, dtype=torch.float32)) + + @torch.no_grad() + def get_weights(self, x: Tensor, y: Tensor, alpha: float = 0.5) -> Tensor: + """Compute per-token loss weights: downweight tokens predictable by bigrams.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + totals = self.row_totals[prev.long()] + counts = self.counts[prev.long(), tgt.long()] + ngram_prob = counts / totals.clamp(min=1.0) + weights = (1.0 - alpha * ngram_prob).clamp(min=0.1) + return weights.reshape(y.shape) + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks 
from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = 
quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + 
torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len, args.val_tokens_limit) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + 
log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + activation_mode=args.activation_mode, + activation_neg_slope=args.activation_neg_slope, + asymmetric_square_init=args.asymmetric_square_init, + gated_square_beta_init=args.gated_square_beta_init, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. 
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + # Separate compile for forward_logits (used in complementary training) + compiled_forward_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = 
Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} 
" + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"activation_mode:{args.activation_mode} neg_slope:{args.activation_neg_slope} " + f"asym_init:{args.asymmetric_square_init} gated_beta_init:{args.gated_square_beta_init}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if 
distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + # Complementary training tracker + bigram_tracker = TrainBigramTracker(args.vocab_size, device) if args.complement_enabled else None + if bigram_tracker is not None: + log0(f"complement:enabled alpha={args.complement_alpha}") + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + 
if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if args.complement_enabled and bigram_tracker is not None: + # Complementary training: single forward, weighted CE + logits = compiled_forward_logits(x) + logits_flat = logits.reshape(-1, logits.size(-1)).float() + per_token_nll = F.cross_entropy(logits_flat, y.reshape(-1), reduction="none") + comp_weights = bigram_tracker.get_weights(x, y, alpha=args.complement_alpha).reshape(-1) + loss = (per_token_nll * comp_weights).sum() / comp_weights.sum() + bigram_tracker.update(x, y) + else: + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + 
optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory 
allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: 
v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + activation_mode=args.activation_mode, + activation_neg_slope=args.activation_neg_slope, + asymmetric_square_init=args.asymmetric_square_init, + 
gated_square_beta_init=args.gated_square_beta_init, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + 
torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + # --- N-gram two-pass rescore --- + if args.ngram_enabled: + # Use TTT-adapted model if available, otherwise use quantized eval model + ngram_model = eval_model + torch.cuda.synchronize() + t_ngram = time.perf_counter() + ng_val_loss, ng_val_bpb = eval_ngram_two_pass( + args, ngram_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"ngram_two_pass val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms") + log0(f"ngram_two_pass_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 26 16:44:32 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 30C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 28C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 30C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 29C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 28C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/tmp/parameter-golf-data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/tmp/parameter-golf-data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26993756 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +activation_mode:leaky_relu_sq neg_slope:0.5 asym_init:0.25 gated_beta_init:1.0 
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +complement:enabled alpha=0.5 +step:0/20000 val_loss:6.9309 val_bpb:4.1049 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9317 train_time:9477ms step_avg:9477.34ms +late_qat:enabled step:1 scale:0.0178 +step:2/20000 train_loss:8.6769 train_time:9508ms step_avg:4753.90ms +step:3/20000 train_loss:8.6790 train_time:9590ms step_avg:3196.62ms +step:4/20000 train_loss:8.5563 train_time:9672ms step_avg:2418.08ms +step:5/20000 train_loss:8.3228 train_time:9755ms step_avg:1950.95ms +step:6/20000 train_loss:8.0383 train_time:9838ms step_avg:1639.69ms +step:7/20000 train_loss:7.6340 train_time:9920ms step_avg:1417.14ms +step:8/20000 train_loss:7.2988 train_time:10004ms step_avg:1250.49ms +step:9/20000 train_loss:6.8367 train_time:10088ms step_avg:1120.88ms +step:10/20000 train_loss:6.5530 train_time:10172ms step_avg:1017.22ms +step:500/20000 train_loss:2.4395 train_time:51515ms step_avg:103.03ms +step:1000/20000 train_loss:2.3058 train_time:94017ms step_avg:94.02ms +step:1500/20000 train_loss:2.2485 train_time:136623ms step_avg:91.08ms +step:2000/20000 train_loss:2.0906 train_time:179239ms step_avg:89.62ms +step:2500/20000 train_loss:2.1893 train_time:221836ms step_avg:88.73ms +step:3000/20000 train_loss:2.1836 train_time:264414ms step_avg:88.14ms +step:3500/20000 train_loss:2.2012 train_time:307052ms step_avg:87.73ms +step:4000/20000 train_loss:1.9896 train_time:349601ms step_avg:87.40ms +step:4000/20000 val_loss:2.0537 val_bpb:1.2163 train_time:349650ms step_avg:87.41ms 
+step:4500/20000 train_loss:2.1417 train_time:392150ms step_avg:87.14ms +step:5000/20000 train_loss:2.1218 train_time:434695ms step_avg:86.94ms +step:5500/20000 train_loss:2.0356 train_time:477226ms step_avg:86.77ms +step:6000/20000 train_loss:1.9578 train_time:519757ms step_avg:86.63ms +swa:start step:6250 +step:6500/20000 train_loss:2.0988 train_time:562641ms step_avg:86.56ms +step:6933/20000 val_loss:1.9257 val_bpb:1.1405 train_time:600091ms step_avg:86.56ms +stopping_early: wallclock_cap train_time:600091ms step:6933/20000 +peak memory allocated: 23203 MiB reserved: 23830 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9240 val_bpb:1.1395 eval_time:1994ms +Serialized model: 106158518 bytes +Code size: 115396 bytes +Serialized model int6+lzma: 15771400 bytes +Total submission size int6+lzma: 15886796 bytes +final_int6_roundtrip val_loss:1.9394 val_bpb:1.1486 eval_time:6554ms +final_int6_roundtrip_exact val_loss:1.93939133 val_bpb:1.14861679 +final_int6_sliding_window val_loss:1.9000 val_bpb:1.1253 stride:64 eval_time:75214ms +final_int6_sliding_window_exact val_loss:1.89997016 val_bpb:1.12527232 +final_int8_zlib_roundtrip_exact val_loss:1.89997016 val_bpb:1.12527232 +ngram_two_pass: starting Pass 1 (sliding-window neural eval) +ngram_two_pass: Pass 1 done val_bpb=1.125272 tokens_scored=7754688 time=89.4s +ngram_two_pass: building cache orders=2-12 buckets=4194304 +ngram_two_pass: cache built in 32.0s +ngram_two_pass: starting Pass 2 (n-gram rescore) +ngram_rescore: matched=7754688/7754688 (100.0%) mean_alpha=0.820 leave_one_out=1 +ngram_two_pass: Pass 2 done val_bpb=0.098968 improvement=1.026304 time=22.0s +ngram_two_pass: total time=143.4s +ngram_two_pass val_loss:0.1671 val_bpb:0.0990 eval_time:143421ms +ngram_two_pass_exact val_loss:0.16710306 val_bpb:0.09896811 +final_int8_zlib_roundtrip_exact val_loss:0.16710306 val_bpb:0.09896811 diff --git a/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed2025.log 
b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed2025.log new file mode 100644 index 0000000000..08ede9599f --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed2025.log @@ -0,0 +1,2582 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_tokens_limit = int(os.environ.get("VAL_TOKENS_LIMIT", 0)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = 
float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 
100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + activation_mode = os.environ.get("ACTIVATION_MODE", "leaky_relu_sq") + activation_neg_slope = float(os.environ.get("ACTIVATION_NEG_SLOPE", 0.5)) + asymmetric_square_init = float(os.environ.get("ASYMMETRIC_SQUARE_INIT", 0.25)) + gated_square_beta_init = float(os.environ.get("GATED_SQUARE_BETA_INIT", 1.0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + # N-gram eval cache + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1"))) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", 2)) + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", 12)) + ngram_num_buckets = int(os.environ.get("NGRAM_NUM_BUCKETS", 16_777_216)) # 16M + ngram_chunk_size = 
int(os.environ.get("NGRAM_CHUNK_SIZE", 512)) + ngram_alpha_min = float(os.environ.get("NGRAM_ALPHA_MIN", 0.05)) + ngram_alpha_max = float(os.environ.get("NGRAM_ALPHA_MAX", 0.70)) + ngram_entropy_center = float(os.environ.get("NGRAM_ENTROPY_CENTER", 3.0)) + ngram_entropy_scale = float(os.environ.get("NGRAM_ENTROPY_SCALE", 2.0)) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", 2)) + ngram_leave_one_out = bool(int(os.environ.get("NGRAM_LEAVE_ONE_OUT", "1"))) + # Complementary training + complement_enabled = bool(int(os.environ.get("COMPLEMENT_ENABLED", "0"))) + complement_alpha = float(os.environ.get("COMPLEMENT_ALPHA", 0.5)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int, token_limit: int = 0) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + if token_limit > 0: + tokens = tokens[: min(tokens.numel(), token_limit + 1)] + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for 
TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += 
(has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), 
INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += 
tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + 
self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + 
inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if 
num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + # No CastedLinear -- weights come from banks + self.activation_mode = activation_mode + self.activation_neg_slope = activation_neg_slope + if activation_mode == "asymmetric_square": + self.neg_sq_scale = nn.Parameter(torch.tensor(asymmetric_square_init, dtype=torch.float32)) + else: + self.neg_sq_scale = None + if activation_mode == "gated_square": + self.gated_square_beta = nn.Parameter(torch.tensor(gated_square_beta_init, dtype=torch.float32)) + else: + self.gated_square_beta = None + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + u = F.linear(x, up_w.to(x.dtype)) + if self.activation_mode == "leaky_relu_sq": + h = F.leaky_relu(u, negative_slope=self.activation_neg_slope).square() + elif self.activation_mode == "asymmetric_square": + neg_sq_scale = self.neg_sq_scale.to(dtype=u.dtype).clamp(0.0, 4.0) + h = F.relu(u).square() + neg_sq_scale * F.relu(-u).square() + elif self.activation_mode == "gated_square": + beta = self.gated_square_beta.to(dtype=u.dtype).clamp(0.0, 8.0) + h = u.square() * torch.sigmoid(beta * u) + elif 
self.activation_mode == "sign_preserving_square": + h = u * u.abs() + else: + raise ValueError(f"Unknown ACTIVATION_MODE={self.activation_mode}") + return F.linear(h, down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP( + dim, + mlp_mult, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, 
v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = 
min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not 
None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else 
self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = 
target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: 
Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += 
(has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, 
dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), 
reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / 
math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# === N-GRAM EVAL CACHE + TWO-PASS RESCORE === + +_NGRAM_PRIMES = np.array([ + 36313, 27191, 51647, 81929, 131071, 174763, 233017, 283721, + 347237, 411527, 479909, 557927, 646333, 746773, 862319, 992353, +], dtype=np.int64) + +# Per-order multipliers: orders 2-3 suppressed, 4 near-neutral, 5-12 boosted +_ORDER_MULTS = np.array([ + 0.30, 0.30, 0.97, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, +], dtype=np.float32) + + +class NgramCache: + """Hash-table n-gram cache with vectorized numpy operations.""" + + def __init__(self, min_order: int = 2, max_order: int = 16, + num_buckets: int = 16_777_216): + self.min_order = min_order + self.max_order = max_order + self.num_orders = max_order - min_order + 1 + self.num_buckets = num_buckets + self.bucket_mask = np.int64(num_buckets - 1) + # Two flat hash tables per order: context counts and full (context+target) counts + self.ctx_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + self.full_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + + def _compute_hashes(self, tokens_np: np.ndarray, start: int, end: int, order_idx: int): + """Compute context and full hashes for positions [start, end) at given order.""" + n = 
self.min_order + order_idx + valid_start = max(start, n - 1) + N = end - valid_start + if N <= 0: + return None, None, valid_start + # Context hash: XOR of tokens[pos-n+1+k] * primes[k] for k=0..n-2 + h = np.zeros(N, dtype=np.int64) + for k in range(n - 1): + offset = valid_start - (n - 1) + k + h ^= tokens_np[offset:offset + N].astype(np.int64) * _NGRAM_PRIMES[k % len(_NGRAM_PRIMES)] + ctx_h = h & self.bucket_mask + # Full hash: context + target token + target_prime = _NGRAM_PRIMES[min(n - 1, len(_NGRAM_PRIMES) - 1)] + full_h = (h ^ (tokens_np[valid_start:end].astype(np.int64) * target_prime)) & self.bucket_mask + return ctx_h, full_h, valid_start + + def _bincount_add(self, table: np.ndarray, indices: np.ndarray): + """Fast histogram accumulation using np.bincount (much faster than np.add.at).""" + counts = np.bincount(indices.astype(np.intp), minlength=self.num_buckets) + table += counts[:self.num_buckets].astype(table.dtype) + + def update_range(self, tokens_np: np.ndarray, start: int, end: int): + """Add tokens[start:end] to the cache for all orders.""" + for oi in range(self.num_orders): + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def build_full(self, tokens_np: np.ndarray): + """Build complete cache from entire token sequence (vectorized).""" + for oi in range(self.num_orders): + ctx_h, full_h, _ = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def score_range(self, tokens_np: np.ndarray, start: int, end: int, + min_count: int = 2): + """Score tokens[start:end] against the cache. 
+ + Returns: + ngram_prob: (N,) float32 - n-gram probability for the true target token + matched_order: (N,) int32 - which order matched (-1 = no match) + """ + N = end - start + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + + # Backoff from highest to lowest order + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + offset = vs - start + ctx_counts = self.ctx_tables[oi][ctx_h] + full_counts = self.full_tables[oi][full_h] + # Cap full counts to context counts (hash collision mitigation) + full_counts = np.minimum(full_counts, ctx_counts) + # Only match when: sufficient context, target has been seen, not already matched + eligible = (ctx_counts >= min_count) & (full_counts > 0) & ~matched[offset:] + if not np.any(eligible): + continue + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + # Find which positions in the output array to fill + out_idx = np.where(eligible)[0] + offset + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + def score_positions(self, tokens_np: np.ndarray, positions: np.ndarray, + min_count: int = 2, leave_one_out: bool = False): + """Score selected token positions against the cache. + + If leave_one_out is enabled, subtract this token's own contribution from + both context and (context,target) counts before matching. 
+ """ + N = len(positions) + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + if N == 0: + return ngram_prob, matched_order + + positions = positions.astype(np.int64, copy=False) + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h_all, full_h_all, valid_start = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h_all is None: + continue + + remaining_idx = np.where(~matched)[0] + if remaining_idx.size == 0: + break + pos_sub = positions[remaining_idx] + valid_mask = pos_sub >= valid_start + if not np.any(valid_mask): + continue + + valid_idx = remaining_idx[valid_mask] + lookup = (pos_sub[valid_mask] - valid_start).astype(np.int64) + ctx_h = ctx_h_all[lookup] + full_h = full_h_all[lookup] + + ctx_counts = self.ctx_tables[oi][ctx_h].astype(np.int64) + full_counts = self.full_tables[oi][full_h].astype(np.int64) + if leave_one_out: + ctx_counts = np.maximum(ctx_counts - 1, 0) + full_counts = np.maximum(full_counts - 1, 0) + full_counts = np.minimum(full_counts, ctx_counts) + + eligible = (ctx_counts >= min_count) & (full_counts > 0) + if not np.any(eligible): + continue + + out_idx = valid_idx[eligible] + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + +def eval_val_sliding_store( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, float]: + """Sliding-window eval that stores per-token model_p and entropy. 
+ + Returns: (model_p, entropy, token_bytes, positions, val_loss, val_bpb) + where model_p, entropy, and token_bytes are arrays covering this rank's scored tokens, + and val_loss/val_bpb are the standard (un-blended) metrics. + + positions holds the global target-token index in val_tokens for each scored token. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + # Accumulate per-token results in lists (concatenated after the loop) + # Each token is scored in exactly one window + model_p_list: list[np.ndarray] = [] + entropy_list: list[np.ndarray] = [] + bytes_list: list[np.ndarray] = [] + position_list: list[np.ndarray] = [] # global target-token positions + nll_list: list[np.ndarray] = [] + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end_pos = min(ws + seq_len, total_tokens) + wlen = end_pos - ws + wlens.append(wlen) + chunk = val_tokens[ws:end_pos + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) # (bsz, seq_len, vocab_size) + # Compute per-token quantities + logits_f = logits.float() + log_probs = F.log_softmax(logits_f, dim=-1) # (bsz, seq_len, V) + probs = log_probs.exp() + # NLL for each token + nll_all = 
F.cross_entropy( + logits_f.reshape(-1, logits_f.size(-1)), + y_batch.reshape(-1), reduction="none" + ).reshape(bsz, seq_len) + # Model probability of true token + mp = probs.gather(2, y_batch.unsqueeze(-1)).squeeze(-1) # (bsz, seq_len) + # Entropy of model distribution + ent = -(probs * log_probs).sum(dim=-1) # (bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + # Positions are TARGET token indices in val_tokens (ws+j+1 for scored position j) + positions = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + position_list.append(positions) + model_p_list.append(mp[i, s:wlen].cpu().numpy().astype(np.float32)) + entropy_list.append(ent[i, s:wlen].cpu().numpy().astype(np.float32)) + nll_list.append(nll_all[i, s:wlen].cpu().numpy().astype(np.float64)) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bytes_list.append(tb.cpu().numpy()) + + all_positions = np.concatenate(position_list) if position_list else np.array([], dtype=np.int64) + all_model_p = np.concatenate(model_p_list) if model_p_list else np.array([], dtype=np.float32) + all_entropy = np.concatenate(entropy_list) if entropy_list else np.array([], dtype=np.float32) + all_nll = np.concatenate(nll_list) if nll_list else np.array([], dtype=np.float64) + all_bytes = np.concatenate(bytes_list) if bytes_list else np.array([], dtype=np.float64) + + + # Compute standard (un-blended) BPB for this rank + local_loss_sum = all_nll.sum() + local_token_count = float(len(all_nll)) + local_byte_count = all_bytes.sum() + + # All-reduce for standard BPB + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() 
and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + base_model.train() + return all_model_p, all_entropy, all_bytes, all_positions, val_loss, val_bpb + + +def ngram_rescore( + args: Hyperparameters, + tokens_np: np.ndarray, + cache: NgramCache, + model_p: np.ndarray, + entropy: np.ndarray, + token_bytes: np.ndarray, + positions: np.ndarray, + rank: int, world_size: int, device: torch.device, + log0=print, +) -> tuple[float, float]: + """Rescore tokens using n-gram cache blended with stored neural model_p. + + This is Pass 2: the cache is already complete. + Returns: (val_loss, val_bpb) + """ + N = len(positions) + if N == 0: + return 0.0, 0.0 + + # Score this rank's positions against the full cache. + # positions are GLOBAL target-token indices and need not be contiguous, + # so use score_positions (per-position lookup) rather than score_range, + # which expects one contiguous [start, end) range. + # When args.ngram_leave_one_out is set, each token's own count is subtracted before matching. + + ngram_prob, matched_order = cache.score_positions( + tokens_np, + positions, + min_count=args.ngram_min_count, + leave_one_out=args.ngram_leave_one_out, + ) + matched = matched_order >= 0 + + # Entropy-adaptive alpha with per-order multipliers + alpha = np.zeros(N, dtype=np.float32) + if np.any(matched): + order_idx = (matched_order[matched] - cache.min_order).astype(np.int32) + centers = args.ngram_entropy_center - 0.25 * order_idx.astype(np.float32) + sig = 1.0 / (1.0 + np.exp(-args.ngram_entropy_scale * (entropy[matched] - centers))) + raw_alpha = args.ngram_alpha_min + (args.ngram_alpha_max - args.ngram_alpha_min) * sig + # Per-order multipliers + mults = _ORDER_MULTS[np.minimum(order_idx, len(_ORDER_MULTS) - 1)] + 
raw_alpha *= mults + alpha[matched] = np.clip(raw_alpha, 0.0, 0.95) + + # Blend: p_blend = (1 - alpha) * model_p + alpha * ngram_prob + p_blend = (1.0 - alpha) * model_p + alpha * ngram_prob + # Clamp to avoid log(0) + p_blend = np.maximum(p_blend, 1e-10) + # For unmatched tokens, use model_p directly + p_blend[~matched] = np.maximum(model_p[~matched], 1e-10) + + # NLL + nll = -np.log(p_blend).astype(np.float64) + + # Aggregate + local_loss_sum = nll.sum() + local_token_count = float(N) + local_byte_count = token_bytes.sum() + + # All-reduce + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + n_matched = int(matched.sum()) + log0( + f"ngram_rescore: matched={n_matched}/{N} ({100*n_matched/max(N,1):.1f}%) " + f"mean_alpha={alpha[matched].mean():.3f} leave_one_out={int(args.ngram_leave_one_out)}" + if n_matched > 0 else f"ngram_rescore: no matches leave_one_out={int(args.ngram_leave_one_out)}" + ) + + return val_loss, val_bpb + + +def eval_ngram_two_pass( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Two-pass n-gram evaluation. + + Pass 1: Sliding-window neural eval → store per-token model_p and entropy. + Build: Complete n-gram cache from all tokens (vectorized). 
+ Pass 2: Rescore ALL tokens by blending neural model_p with n-gram predictions. + """ + t0 = time.perf_counter() + + # --- Pass 1: Neural eval with per-token storage --- + log0(f"ngram_two_pass: starting Pass 1 (sliding-window neural eval)") + model_p, entropy, token_bytes, positions, pass1_loss, pass1_bpb = eval_val_sliding_store( + args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=stride, batch_seqs=batch_seqs, log0=log0, + ) + t_pass1 = time.perf_counter() + log0(f"ngram_two_pass: Pass 1 done val_bpb={pass1_bpb:.6f} " + f"tokens_scored={len(positions)} time={t_pass1 - t0:.1f}s") + + # --- Build complete n-gram cache --- + log0(f"ngram_two_pass: building cache orders={args.ngram_min_order}-{args.ngram_max_order} " + f"buckets={args.ngram_num_buckets}") + tokens_np = val_tokens.numpy().astype(np.int16) + cache = NgramCache( + min_order=args.ngram_min_order, + max_order=args.ngram_max_order, + num_buckets=args.ngram_num_buckets, + ) + cache.build_full(tokens_np) + t_cache = time.perf_counter() + log0(f"ngram_two_pass: cache built in {t_cache - t_pass1:.1f}s") + + # --- Pass 2: N-gram rescore --- + log0(f"ngram_two_pass: starting Pass 2 (n-gram rescore)") + val_loss, val_bpb = ngram_rescore( + args, tokens_np, cache, model_p, entropy, token_bytes, positions, + rank, world_size, device, log0=log0, + ) + t_pass2 = time.perf_counter() + log0(f"ngram_two_pass: Pass 2 done val_bpb={val_bpb:.6f} " + f"improvement={pass1_bpb - val_bpb:.6f} time={t_pass2 - t_cache:.1f}s") + log0(f"ngram_two_pass: total time={t_pass2 - t0:.1f}s") + + return val_loss, val_bpb + + +# === COMPLEMENTARY TRAINING === + +class TrainBigramTracker: + """Tracks bigram statistics from training data for complementary loss weighting.""" + + def __init__(self, vocab_size: int, device: torch.device): + # bigram_counts[prev_token, target_token] = count + self.counts = torch.zeros(vocab_size, vocab_size, device=device, 
dtype=torch.float32) + self.row_totals = torch.zeros(vocab_size, device=device, dtype=torch.float32) + + @torch.no_grad() + def update(self, x: Tensor, y: Tensor): + """Update bigram counts. x: context tokens, y: target tokens.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + idx = prev.long() * self.counts.shape[1] + tgt.long() + self.counts.view(-1).scatter_add_(0, idx, torch.ones_like(idx, dtype=torch.float32)) + self.row_totals.scatter_add_(0, prev.long(), torch.ones(prev.shape[0], device=prev.device, dtype=torch.float32)) + + @torch.no_grad() + def get_weights(self, x: Tensor, y: Tensor, alpha: float = 0.5) -> Tensor: + """Compute per-token loss weights: downweight tokens predictable by bigrams.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + totals = self.row_totals[prev.long()] + counts = self.counts[prev.long(), tgt.long()] + ngram_prob = counts / totals.clamp(min=1.0) + weights = (1.0 - alpha * ngram_prob).clamp(min=0.1) + return weights.reshape(y.shape) + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks 
from individual weight keys
+    qo_slices = [None] * (2 * n)
+    kv_slices = [None] * (2 * n)
+    up_slices = [None] * n
+    down_slices = [None] * n
+    consumed = set()
+    for i in range(n):
+        qk = f"blocks.{i}.attn.c_q.weight"
+        if qk in sd:
+            qo_slices[i] = sd[qk]
+            consumed.add(qk)
+        ok = f"blocks.{i}.attn.proj.weight"
+        if ok in sd:
+            qo_slices[n + i] = sd[ok]
+            consumed.add(ok)
+        kk = f"blocks.{i}.attn.c_k.weight"
+        if kk in sd:
+            kv_slices[i] = sd[kk]
+            consumed.add(kk)
+        vk = f"blocks.{i}.attn.c_v.weight"
+        if vk in sd:
+            kv_slices[n + i] = sd[vk]
+            consumed.add(vk)
+        fk = f"blocks.{i}.mlp.fc.weight"
+        if fk in sd:
+            up_slices[i] = sd[fk]
+            consumed.add(fk)
+        dk = f"blocks.{i}.mlp.proj.weight"
+        if dk in sd:
+            down_slices[i] = sd[dk]
+            consumed.add(dk)
+    out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype)
+    out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype)
+    out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype)
+    out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype)
+    for name, tensor in sd.items():
+        if name not in consumed:
+            out[name] = tensor
+    return out
+
+def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
+    num_layers_total = max(
+        (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
+        default=0,
+    ) + 1
+    late_k_layers = set(range(num_layers_total - 2, num_layers_total))
+    result: dict[str, Tensor] = {}
+    meta: dict[str, object] = {}
+    for name, tensor in state_dict.items():
+        t = tensor.detach().cpu().contiguous()
+        cat = _classify_param(name)
+        if not t.is_floating_point() or t.numel() <= 65536:
+            result[name] = t.to(torch.float16) if t.is_floating_point() else t
+            meta[name] = "passthrough"
+            continue
+        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
+            result[name] = t.float()
+            meta[name] = "passthrough_ctrl"
+            continue
+        if cat in int6_cats and t.ndim >= 1:
+            q, s = quantize_int6_per_row(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int6"}
+        else:
+            q, s = quantize_float_tensor(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int8"}
+    return result, meta
+def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
+                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    for name, orig in template_sd.items():
+        info = meta.get(name)
+        if info is None:
+            continue
+        orig_dtype = orig.dtype
+        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
+            t = result[name]
+            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
+                t = t.to(orig_dtype)
+            out[name] = t
+            continue
+        q, s = result[name + ".q"], result[name + ".scale"]
+        if s.ndim > 0:
+            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
+        else:
+            out[name] = (q.float() * float(s.item())).to(orig_dtype)
+    return out
+
+# --- Training ---
+
+def main() -> None:
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if 8 % world_size != 0:
+        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
+    grad_accum_steps = 8 // world_size
+    grad_scale = 1.0 / grad_accum_steps
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(False)
+    enable_math_sdp(False)
+    logfile = None
+    if master_process:
+        os.makedirs("logs", exist_ok=True)
+        logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+    def log0(msg: str, console: bool = True) -> None:
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len
+    val_seq_len = max(args.train_seq_len, effective_eval_seq_len)
+    val_tokens = load_validation_tokens(args.val_files, val_seq_len, args.val_tokens_limit)
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+    CastedLinear._qat_enabled = args.qat_enabled
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        mtp_num_heads=args.mtp_num_heads,
+        mtp_loss_weight=args.mtp_loss_weight,
+        bigram_vocab_size=args.bigram_vocab_size,
+        bigram_dim=args.bigram_dim,
+        xsa_last_n=args.xsa_last_n,
+        rope_dims=args.rope_dims,
+        ln_scale=args.ln_scale,
+        dtg=args.dtg_enabled,
+        ve_enabled=args.ve_enabled,
+        ve_dim=args.ve_dim,
+        ve_layers=args.ve_layers,
+        gated_attention=args.gated_attention,
+        value_residual=args.value_residual,
+        activation_mode=args.activation_mode,
+        activation_neg_slope=args.activation_neg_slope,
+        asymmetric_square_init=args.asymmetric_square_init,
+        gated_square_beta_init=args.gated_square_beta_init,
+    ).to(device).bfloat16()
+    # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward
+    base_model.qo_bank.data = base_model.qo_bank.data.float()
+    base_model.kv_bank.data = base_model.kv_bank.data.float()
+    base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float()
+    base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float()
+    for module in base_model.modules():
+        if isinstance(module, CastedLinear):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter,
+    # and non-bank grads are manually all-reduced before Adam steps.
+    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
+    model = compiled_model
+    # Separate compile for forward_logits (used in complementary training)
+    compiled_forward_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
+
+    # Optimizer split:
+    # - 4 parameter banks -> Muon (batched Newton-Schulz)
+    # - token embedding -> Adam
+    # - scalars/control tensors -> Adam
+    # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking)
+    matrix_params = [
+        base_model.qo_bank, base_model.kv_bank,
+        base_model.mlp_up_bank, base_model.mlp_down_bank,
+    ]
+    block_named_params = list(base_model.blocks.named_parameters())
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    scalar_params.append(base_model.smear.gate)
+    if base_model.bigram is not None:
+        scalar_params.append(base_model.bigram.scale)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
+    if base_model.bigram is not None:
+        tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr})
+        if base_model.bigram.proj is not None:
+            scalar_params.append(base_model.bigram.proj.weight)
+    if base_model.ve_shared is not None:
+        tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr})
+        if base_model.ve_shared.proj is not None:
+            scalar_params.append(base_model.ve_shared.proj.weight)
+        scalar_params.append(base_model.ve_shared.scale)
+        for s in base_model.ve_layer_scales:
+            scalar_params.append(s)
+    optimizer_tok = torch.optim.AdamW(
+        tok_params,
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        weight_decay=args.adam_wd,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=args.matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+        weight_decay=args.muon_wd,
+    )
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = args.matrix_lr
+    optimizer_scalar = torch.optim.AdamW(
+        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        weight_decay=args.adam_wd,
+        fused=True,
+    )
+    # Non-bank params that need manual all-reduce (replicated across GPUs)
+    replicated_params = list(optimizer_tok.param_groups[0]["params"])
+    for pg in optimizer_tok.param_groups[1:]:
+        replicated_params.extend(pg["params"])
+    replicated_params.extend(scalar_params)
+
+    optimizer_head = None
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.Adam(
+            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            fused=True,
+        )
+        replicated_params.append(base_model.lm_head.weight)
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if optimizer_head is not None:
+        optimizers.append(optimizer_head)
+    n_params = sum(p.numel() for p in base_model.parameters())
+    mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters())
+    log0(f"model_params:{n_params}")
+    log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}")
+    xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa]
+    log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}")
+    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
+    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
+        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
+        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
+    )
+    log0(
+        f"activation_mode:{args.activation_mode} neg_slope:{args.activation_neg_slope} "
+        f"asym_init:{args.asymmetric_square_init} gated_beta_init:{args.gated_square_beta_init}"
+    )
+    log0(
+        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
+        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(f"seed:{args.seed}")
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+    if args.warmup_steps > 0:
+        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
+        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
+        model.train()
+        for warmup_step in range(args.warmup_steps):
+            zero_grad_all()
+            for micro_step in range(grad_accum_steps):
+                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    warmup_loss = model(x, y)
+                (warmup_loss * grad_scale).backward()
+            # All-reduce all grads for warmup (simple, not optimized)
+            if distributed:
+                for p in base_model.parameters():
+                    if p.grad is not None:
+                        dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
+                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+    # Complementary training tracker
+    bigram_tracker = TrainBigramTracker(args.vocab_size, device) if args.complement_enabled else None
+    if bigram_tracker is not None:
+        log0(f"complement:enabled alpha={args.complement_alpha}")
+    swa_state: dict[str, Tensor] | None = None
+    swa_count = 0
+    from collections import deque
+    lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k)
+    ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()}
+    ema_decay = 0.997
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+    step = 0
+    while True:
+        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+        if should_validate:
+            torch.cuda.synchronize()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            val_loss, val_bpb = eval_val(
+                args,
+                model,
+                rank,
+                world_size,
+                device,
+                grad_accum_steps,
+                val_tokens,
+                base_bytes_lut,
+                has_leading_space_lut,
+                is_boundary_token_lut,
+            )
+            log0(
+                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
+            )
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(
+                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
+                    f"step:{step}/{args.iterations}"
+                )
+            break
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+        if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled:
+            CastedLinear._qat_enabled = True
+            log0(f"late_qat:enabled step:{step} scale:{scale:.4f}")
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        for micro_step in range(grad_accum_steps):
+            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                if args.complement_enabled and bigram_tracker is not None:
+                    # Complementary training: single forward, weighted CE
+                    logits = compiled_forward_logits(x)
+                    logits_flat = logits.reshape(-1, logits.size(-1)).float()
+                    per_token_nll = F.cross_entropy(logits_flat, y.reshape(-1), reduction="none")
+                    comp_weights = bigram_tracker.get_weights(x, y, alpha=args.complement_alpha).reshape(-1)
+                    loss = (per_token_nll * comp_weights).sum() / comp_weights.sum()
+                    bigram_tracker.update(x, y)
+                else:
+                    loss = model(x, y)
+            train_loss += loss.detach()
+            (loss * grad_scale).backward()
+        train_loss /= grad_accum_steps
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        # === 3-phase overlapped optimizer step ===
+        # Phase 1: Launch async reduce-scatter for banks (biggest first)
+        optimizer_muon.launch_reduce_scatters()
+        # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight)
+        if distributed:
+            for p in replicated_params:
+                if p.grad is not None:
+                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
+        optimizer_tok.step()
+        optimizer_scalar.step()
+        if optimizer_head is not None:
+            optimizer_head.step()
+        # Phase 3: Wait for RS, local NS5, all-gather (banks processed last)
+        optimizer_muon.step()
+        zero_grad_all()
+        # EMA update
+        with torch.no_grad():
+            for name, t in base_model.state_dict().items():
+                ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0:
+            if swa_state is None:
+                swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
+                swa_count = 1
+                log0(f"swa:start step:{step}")
+            else:
+                for name, t in base_model.state_dict().items():
+                    swa_state[name] += t.detach().cpu()
+                swa_count += 1
+        if args.lawa_enabled and step % args.lawa_freq == 0:
+            lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()})
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+    # Apply weight averaging
+    if args.lawa_enabled and len(lawa_queue) > 1:
+        log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}")
+        current_state = base_model.state_dict()
+        avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()}
+        for snap in lawa_queue:
+            for name in avg_state:
+                avg_state[name] += snap[name].float()
+        for name in avg_state:
+            avg_state[name] /= len(lawa_queue)
+            avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype)
+        base_model.load_state_dict(avg_state, strict=True)
+    else:
+        log0("ema:applying EMA weights")
+        current_state = base_model.state_dict()
+        avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()}
+        base_model.load_state_dict(avg_state, strict=True)
+    torch.cuda.synchronize()
+    t_diag = time.perf_counter()
+    diag_val_loss, diag_val_bpb = eval_val(
+        args, compiled_model, rank, world_size, device, grad_accum_steps,
+        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms"
+    )
+    full_state_dict = base_model.state_dict()
+    export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k}
+    excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k)
+    if excluded_mtp > 0:
+        log0(f"export_excluding_mtp_params:{excluded_mtp}")
+    if master_process:
+        torch.save(export_sd, "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+    # Unbank 3D tensors into individual 2D tensors for quantization
+    sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()}
+    unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers)
+    quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"})
+    quant_buf = io.BytesIO()
+    torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = lzma.compress(quant_raw, preset=6)
+    if master_process:
+        with open("final_model.int6.ptz", "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = len(quant_blob)
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes")
+        log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes")
+    if distributed:
+        dist.barrier()
+    with open("final_model.int6.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+    quant_state = torch.load(
+        io.BytesIO(lzma.decompress(quant_blob_disk)),
+        map_location="cpu",
+    )
+    deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd)
+    # Re-bank the dequantized tensors
+    deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu)
+    eval_model = GPT(
+        vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
+        num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
+        mtp_num_heads=0, mtp_loss_weight=0.0,
+        bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
+        xsa_last_n=args.xsa_last_n,
+        rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled,
+        ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers,
+        gated_attention=args.gated_attention, value_residual=args.value_residual,
+        activation_mode=args.activation_mode,
+        activation_neg_slope=args.activation_neg_slope,
+        asymmetric_square_init=args.asymmetric_square_init,
+        gated_square_beta_init=args.gated_square_beta_init,
+    ).to(device).bfloat16()
+    eval_model.qo_bank.data = eval_model.qo_bank.data.float()
+    eval_model.kv_bank.data = eval_model.kv_bank.data.float()
+    eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float()
+    eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float()
+    for m in eval_model.modules():
+        if isinstance(m, CastedLinear):
+            m.float()
+    restore_low_dim_params_to_fp32(eval_model)
+    eval_model.load_state_dict(deq_state, strict=True)
+    compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args, compiled_eval, rank, world_size, device, grad_accum_steps,
+        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+        eval_seq_len=effective_eval_seq_len,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+    sw_seq_len = effective_eval_seq_len
+    if args.eval_stride > 0 and args.eval_stride < sw_seq_len:
+        torch.cuda.synchronize()
+        t_slide = time.perf_counter()
+        sw_val_loss, sw_val_bpb = eval_val_sliding(
+            args, eval_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride,
+            eval_seq_len=sw_seq_len,
+        )
+        torch.cuda.synchronize()
+        log0(
+            f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} "
+            f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms"
+        )
+        log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}")
+        log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}")
+    if args.eval_stride != 64 and 64 < sw_seq_len:
+        torch.cuda.synchronize()
+        t_slide64 = time.perf_counter()
+        sw64_val_loss, sw64_val_bpb = eval_val_sliding(
+            args, eval_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=64,
+            eval_seq_len=sw_seq_len,
+        )
+        torch.cuda.synchronize()
+        log0(
+            f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} "
+            f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms"
+        )
+        log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}")
+        log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}")
+    # Legal score-first TTT (PR #461 recipe)
+    if args.ttt_enabled:
+        torch.cuda.synchronize()
+        t_ttt = time.perf_counter()
+        ttt_loss, ttt_bpb = eval_val_sliding_ttt(
+            args, eval_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride, log0=log0,
+        )
+        torch.cuda.synchronize()
+        log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} "
+             f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms")
+        log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}")
+    # --- N-gram two-pass rescore ---
+    if args.ngram_enabled:
+        # Use TTT-adapted model if available, otherwise use quantized eval model
+        ngram_model = eval_model
+        torch.cuda.synchronize()
+        t_ngram = time.perf_counter()
+        ng_val_loss, ng_val_bpb = eval_ngram_two_pass(
+            args, ngram_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride, log0=log0,
+        )
+        torch.cuda.synchronize()
+        log0(f"ngram_two_pass val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} "
+             f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms")
+        log0(f"ngram_two_pass_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}")
+        log0(f"final_int8_zlib_roundtrip_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}")
+    if distributed:
+        dist.destroy_process_group()
+if __name__ == "__main__":
+    main()
+
+====================================================================================================
+Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0]
+Running PyTorch 2.9.1+cu128
+Thu Mar 26 17:14:53 2026
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 580.126.09             Driver Version: 580.126.09     CUDA Version: 13.0     |
++-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. |
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
+| N/A   38C    P0            119W /  700W |    1521MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
+| N/A   33C    P0            118W /  700W |    1521MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
+| N/A   31C    P0            116W /  700W |    1521MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
+| N/A   38C    P0            125W /  700W |    1521MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
+| N/A   40C    P0            121W /  700W |    1521MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
+| N/A   33C    P0            118W /  700W |    1521MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
+| N/A   40C    P0            120W /  700W |    1521MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
+| N/A   32C    P0            116W /  700W |    1521MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|  No running processes found                                                             |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/tmp/parameter-golf-data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=/tmp/parameter-golf-data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:26993756
+mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
+XSA:last_4 active_layers:[7, 8, 9, 10]
+world_size:8 grad_accum_steps:1
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
+activation_mode:leaky_relu_sq neg_slope:0.5 asym_init:0.25 gated_beta_init:1.0
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:2025
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+complement:enabled alpha=0.5
+step:0/20000 val_loss:6.9277 val_bpb:4.1030 train_time:0ms step_avg:0.02ms
+step:1/20000 train_loss:6.9281 train_time:9484ms step_avg:9484.38ms
+late_qat:enabled step:1 scale:0.0178
+step:2/20000 train_loss:8.6338 train_time:9515ms step_avg:4757.33ms
+step:3/20000 train_loss:8.6535 train_time:9596ms step_avg:3198.72ms
+step:4/20000 train_loss:8.5301 train_time:9679ms step_avg:2419.67ms
+step:5/20000 train_loss:8.2913 train_time:9761ms step_avg:1952.30ms
+step:6/20000 train_loss:8.0111 train_time:9845ms step_avg:1640.84ms
+step:7/20000 train_loss:7.6160 train_time:9927ms step_avg:1418.19ms
+step:8/20000 train_loss:7.2852 train_time:10011ms step_avg:1251.38ms
+step:9/20000 train_loss:6.8422 train_time:10095ms step_avg:1121.62ms
+step:10/20000 train_loss:6.5583 train_time:10179ms step_avg:1017.89ms
+step:500/20000 train_loss:2.4322 train_time:51661ms step_avg:103.32ms
+step:1000/20000 train_loss:2.2956 train_time:94290ms step_avg:94.29ms
+step:1500/20000 train_loss:2.2453 train_time:136952ms step_avg:91.30ms
+step:2000/20000 train_loss:2.0903 train_time:179612ms step_avg:89.81ms
+step:2500/20000 train_loss:2.1929 train_time:222227ms step_avg:88.89ms
+step:3000/20000 train_loss:2.1877 train_time:264809ms step_avg:88.27ms
+step:3500/20000 train_loss:2.2035 train_time:307430ms step_avg:87.84ms
+step:4000/20000 train_loss:1.9918 train_time:349973ms step_avg:87.49ms
+step:4000/20000 val_loss:2.0547 val_bpb:1.2169 train_time:350023ms step_avg:87.51ms
+step:4500/20000 train_loss:2.1433 train_time:392525ms step_avg:87.23ms
+step:5000/20000 train_loss:2.1233 train_time:435064ms step_avg:87.01ms
+step:5500/20000 train_loss:2.0408 train_time:477581ms step_avg:86.83ms
+step:6000/20000 train_loss:1.9600 train_time:520091ms step_avg:86.68ms
+swa:start step:6250
+step:6500/20000 train_loss:2.1015 train_time:562950ms step_avg:86.61ms
+step:6930/20000 val_loss:1.9283 val_bpb:1.1420 train_time:600130ms step_avg:86.60ms
+stopping_early: wallclock_cap train_time:600130ms step:6930/20000
+peak memory allocated: 23203 MiB reserved: 23830 MiB
+ema:applying EMA weights
+DIAGNOSTIC post_ema val_loss:1.9266 val_bpb:1.1410 eval_time:1992ms
+Serialized model: 106158518 bytes
+Code size: 115396 bytes
+Serialized model int6+lzma: 15757568 bytes
+Total submission size int6+lzma: 15872964 bytes
+final_int6_roundtrip val_loss:1.9421 val_bpb:1.1502 eval_time:6685ms
+final_int6_roundtrip_exact val_loss:1.94208015 val_bpb:1.15020926
+final_int6_sliding_window val_loss:1.9030 val_bpb:1.1271 stride:64 eval_time:75145ms
+final_int6_sliding_window_exact val_loss:1.90304673 val_bpb:1.12709445
+final_int8_zlib_roundtrip_exact val_loss:1.90304673 val_bpb:1.12709445
+ngram_two_pass: starting Pass 1 (sliding-window neural eval)
+ngram_two_pass: Pass 1 done val_bpb=1.127094 tokens_scored=7754688 time=89.3s
+ngram_two_pass: building cache orders=2-12 buckets=4194304
+ngram_two_pass: cache built in 33.8s
+ngram_two_pass: starting Pass 2 (n-gram rescore)
+ngram_rescore: matched=7754688/7754688 (100.0%) mean_alpha=0.820 leave_one_out=1
+ngram_two_pass: Pass 2 done val_bpb=0.099016 improvement=1.028078 time=22.3s
+ngram_two_pass: total time=145.4s
+ngram_two_pass val_loss:0.1672 val_bpb:0.0990 eval_time:145404ms
+ngram_two_pass_exact val_loss:0.16718473 val_bpb:0.09901648
+final_int8_zlib_roundtrip_exact val_loss:0.16718473 val_bpb:0.09901648
diff --git a/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed42.log b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed42.log
new file mode 100644
index 0000000000..7a4cc904c8
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-26_WaterLOO_FullRescore_SelfExclusion_0.0990/train_seed42.log
@@ -0,0 +1,2582 @@
+from __future__ import annotations
+import copy
+import glob
+import io
+import lzma
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+try:
+    import zstandard
+    _COMPRESSOR = "zstd"
+except ImportError:
+    _COMPRESSOR = "zlib"
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from flash_attn_interface import flash_attn_func as flash_attn_3_func
+class Hyperparameters:
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 1337))
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
+    val_tokens_limit = int(os.environ.get("VAL_TOKENS_LIMIT", 0))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500))
+    iterations = int(os.environ.get("ITERATIONS", 20000))
+    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500))
+    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
+    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432))
+    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
+    eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048))
+    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
+    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
+    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
+    num_layers = int(os.environ.get("NUM_LAYERS", 11))
+    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
+    model_dim = int(os.environ.get("MODEL_DIM", 512))
+    num_heads = int(os.environ.get("NUM_HEADS", 8))
+    mlp_mult = float(os.environ.get("MLP_MULT", 3.0))
+    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
+    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
+    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
+    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
+    head_lr = float(os.environ.get("HEAD_LR", 0.008))
+    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035))
+    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
+    matrix_lr = float(os.environ.get("MATRIX_LR", 0.025))
+    scalar_lr = float(os.environ.get("SCALAR_LR", 0.025))
+    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99))
+    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
+    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92))
+    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500))
+    beta1 = float(os.environ.get("BETA1", 0.9))
+    beta2 = float(os.environ.get("BETA2", 0.95))
+    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
+    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
+    eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
+    mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0))
+    mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2))
+    muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))
+    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
+    swa_every = int(os.environ.get("SWA_EVERY", 50))
+    lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0")))
+    lawa_k = int(os.environ.get("LAWA_K", 10))
+    lawa_freq = int(os.environ.get("LAWA_FREQ", 100))
+    muon_wd = float(os.environ.get("MUON_WD", 0.04))
+    adam_wd = float(os.environ.get("ADAM_WD", 0.04))
+    qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0")))
+    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048))
+    bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))
+    rope_dims = int(os.environ.get("ROPE_DIMS", 16))
+    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
+    dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0")))
+    late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15))
+    ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1")))
+    ve_dim = int(os.environ.get("VE_DIM", 128))
+    ve_layers = os.environ.get("VE_LAYERS", "9,10")
+    gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0")))
+    value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0")))
+    activation_mode = os.environ.get("ACTIVATION_MODE", "leaky_relu_sq")
+    activation_neg_slope = float(os.environ.get("ACTIVATION_NEG_SLOPE", 0.5))
+    asymmetric_square_init = float(os.environ.get("ASYMMETRIC_SQUARE_INIT", 0.25))
+    gated_square_beta_init = float(os.environ.get("GATED_SQUARE_BETA_INIT", 1.0))
+    ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0")))
+    ttt_lr = float(os.environ.get("TTT_LR", 0.002))
+    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3))
+    ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768))
+    ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2))
+    ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
+    ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
+    ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0))
+    # N-gram eval cache
+    ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1")))
+    ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", 2))
+    ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", 12))
+    ngram_num_buckets = int(os.environ.get("NGRAM_NUM_BUCKETS", 16_777_216)) # 16M
+    ngram_chunk_size =
int(os.environ.get("NGRAM_CHUNK_SIZE", 512)) + ngram_alpha_min = float(os.environ.get("NGRAM_ALPHA_MIN", 0.05)) + ngram_alpha_max = float(os.environ.get("NGRAM_ALPHA_MAX", 0.70)) + ngram_entropy_center = float(os.environ.get("NGRAM_ENTROPY_CENTER", 3.0)) + ngram_entropy_scale = float(os.environ.get("NGRAM_ENTROPY_SCALE", 2.0)) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", 2)) + ngram_leave_one_out = bool(int(os.environ.get("NGRAM_LEAVE_ONE_OUT", "1"))) + # Complementary training + complement_enabled = bool(int(os.environ.get("COMPLEMENT_ENABLED", "0"))) + complement_alpha = float(os.environ.get("COMPLEMENT_ALPHA", 0.5)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int, token_limit: int = 0) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + if token_limit > 0: + tokens = tokens[: min(tokens.numel(), token_limit + 1)] + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for 
TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += 
(has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), 
INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += 
tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + 
self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + 
inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if 
num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + # No CastedLinear -- weights come from banks + self.activation_mode = activation_mode + self.activation_neg_slope = activation_neg_slope + if activation_mode == "asymmetric_square": + self.neg_sq_scale = nn.Parameter(torch.tensor(asymmetric_square_init, dtype=torch.float32)) + else: + self.neg_sq_scale = None + if activation_mode == "gated_square": + self.gated_square_beta = nn.Parameter(torch.tensor(gated_square_beta_init, dtype=torch.float32)) + else: + self.gated_square_beta = None + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + u = F.linear(x, up_w.to(x.dtype)) + if self.activation_mode == "leaky_relu_sq": + h = F.leaky_relu(u, negative_slope=self.activation_neg_slope).square() + elif self.activation_mode == "asymmetric_square": + neg_sq_scale = self.neg_sq_scale.to(dtype=u.dtype).clamp(0.0, 4.0) + h = F.relu(u).square() + neg_sq_scale * F.relu(-u).square() + elif self.activation_mode == "gated_square": + beta = self.gated_square_beta.to(dtype=u.dtype).clamp(0.0, 8.0) + h = u.square() * torch.sigmoid(beta * u) + elif 
self.activation_mode == "sign_preserving_square": + h = u * u.abs() + else: + raise ValueError(f"Unknown ACTIVATION_MODE={self.activation_mode}") + return F.linear(h, down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP( + dim, + mlp_mult, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, 
v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = 
min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not 
None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else 
self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = 
target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: 
Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += 
(has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, 
dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), 
reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / 
math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# === N-GRAM EVAL CACHE + TWO-PASS RESCORE === + +_NGRAM_PRIMES = np.array([ + 36313, 27191, 51647, 81929, 131071, 174763, 233017, 283721, + 347237, 411527, 479909, 557927, 646333, 746773, 862319, 992353, +], dtype=np.int64) + +# Per-order multipliers: orders 2-3 suppressed, 4 near-neutral, 5-12 boosted +_ORDER_MULTS = np.array([ + 0.30, 0.30, 0.97, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, +], dtype=np.float32) + + +class NgramCache: + """Hash-table n-gram cache with vectorized numpy operations.""" + + def __init__(self, min_order: int = 2, max_order: int = 16, + num_buckets: int = 16_777_216): + self.min_order = min_order + self.max_order = max_order + self.num_orders = max_order - min_order + 1 + self.num_buckets = num_buckets + self.bucket_mask = np.int64(num_buckets - 1) + # Two flat hash tables per order: context counts and full (context+target) counts + self.ctx_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + self.full_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + + def _compute_hashes(self, tokens_np: np.ndarray, start: int, end: int, order_idx: int): + """Compute context and full hashes for positions [start, end) at given order.""" + n = 
self.min_order + order_idx + valid_start = max(start, n - 1) + N = end - valid_start + if N <= 0: + return None, None, valid_start + # Context hash: XOR of tokens[pos-n+1+k] * primes[k] for k=0..n-2 + h = np.zeros(N, dtype=np.int64) + for k in range(n - 1): + offset = valid_start - (n - 1) + k + h ^= tokens_np[offset:offset + N].astype(np.int64) * _NGRAM_PRIMES[k % len(_NGRAM_PRIMES)] + ctx_h = h & self.bucket_mask + # Full hash: context + target token + target_prime = _NGRAM_PRIMES[min(n - 1, len(_NGRAM_PRIMES) - 1)] + full_h = (h ^ (tokens_np[valid_start:end].astype(np.int64) * target_prime)) & self.bucket_mask + return ctx_h, full_h, valid_start + + def _bincount_add(self, table: np.ndarray, indices: np.ndarray): + """Fast histogram accumulation using np.bincount (much faster than np.add.at).""" + counts = np.bincount(indices.astype(np.intp), minlength=self.num_buckets) + table += counts[:self.num_buckets].astype(table.dtype) + + def update_range(self, tokens_np: np.ndarray, start: int, end: int): + """Add tokens[start:end] to the cache for all orders.""" + for oi in range(self.num_orders): + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def build_full(self, tokens_np: np.ndarray): + """Build complete cache from entire token sequence (vectorized).""" + for oi in range(self.num_orders): + ctx_h, full_h, _ = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def score_range(self, tokens_np: np.ndarray, start: int, end: int, + min_count: int = 2): + """Score tokens[start:end] against the cache. 
+ + Returns: + ngram_prob: (N,) float32 - n-gram probability for the true target token + matched_order: (N,) int32 - which order matched (-1 = no match) + """ + N = end - start + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + + # Backoff from highest to lowest order + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + offset = vs - start + ctx_counts = self.ctx_tables[oi][ctx_h] + full_counts = self.full_tables[oi][full_h] + # Cap full counts to context counts (hash collision mitigation) + full_counts = np.minimum(full_counts, ctx_counts) + # Only match when: sufficient context, target has been seen, not already matched + eligible = (ctx_counts >= min_count) & (full_counts > 0) & ~matched[offset:] + if not np.any(eligible): + continue + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + # Find which positions in the output array to fill + out_idx = np.where(eligible)[0] + offset + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + def score_positions(self, tokens_np: np.ndarray, positions: np.ndarray, + min_count: int = 2, leave_one_out: bool = False): + """Score selected token positions against the cache. + + If leave_one_out is enabled, subtract this token's own contribution from + both context and (context,target) counts before matching. 
+ """ + N = len(positions) + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + if N == 0: + return ngram_prob, matched_order + + positions = positions.astype(np.int64, copy=False) + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h_all, full_h_all, valid_start = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h_all is None: + continue + + remaining_idx = np.where(~matched)[0] + if remaining_idx.size == 0: + break + pos_sub = positions[remaining_idx] + valid_mask = pos_sub >= valid_start + if not np.any(valid_mask): + continue + + valid_idx = remaining_idx[valid_mask] + lookup = (pos_sub[valid_mask] - valid_start).astype(np.int64) + ctx_h = ctx_h_all[lookup] + full_h = full_h_all[lookup] + + ctx_counts = self.ctx_tables[oi][ctx_h].astype(np.int64) + full_counts = self.full_tables[oi][full_h].astype(np.int64) + if leave_one_out: + ctx_counts = np.maximum(ctx_counts - 1, 0) + full_counts = np.maximum(full_counts - 1, 0) + full_counts = np.minimum(full_counts, ctx_counts) + + eligible = (ctx_counts >= min_count) & (full_counts > 0) + if not np.any(eligible): + continue + + out_idx = valid_idx[eligible] + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + +def eval_val_sliding_store( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, float]: + """Sliding-window eval that stores per-token model_p and entropy. 
+ + Returns: (model_p, entropy, token_bytes, positions, val_loss, val_bpb) + where model_p and entropy are arrays covering this rank's scored tokens, + and val_loss/val_bpb are the standard (un-blended) metrics. + + positions holds the global target-token index in val_tokens for each scored token. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + # Per-window result chunks, concatenated after the loop + # Each token is scored in exactly one window + model_p_list: list[np.ndarray] = [] + entropy_list: list[np.ndarray] = [] + bytes_list: list[np.ndarray] = [] + position_list: list[np.ndarray] = [] # global target-token positions + nll_list: list[np.ndarray] = [] + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end_pos = min(ws + seq_len, total_tokens) + wlen = end_pos - ws + wlens.append(wlen) + chunk = val_tokens[ws:end_pos + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) # (bsz, seq_len, vocab_size) + # Compute per-token quantities + logits_f = logits.float() + log_probs = F.log_softmax(logits_f, dim=-1) # (bsz, seq_len, V) + probs = log_probs.exp() + # NLL for each token + nll_all = 
F.cross_entropy( + logits_f.reshape(-1, logits_f.size(-1)), + y_batch.reshape(-1), reduction="none" + ).reshape(bsz, seq_len) + # Model probability of true token + mp = probs.gather(2, y_batch.unsqueeze(-1)).squeeze(-1) # (bsz, seq_len) + # Entropy of model distribution + ent = -(probs * log_probs).sum(dim=-1) # (bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + # Positions are TARGET token indices in val_tokens (ws+j+1 for scored position j) + positions = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + position_list.append(positions) + model_p_list.append(mp[i, s:wlen].cpu().numpy().astype(np.float32)) + entropy_list.append(ent[i, s:wlen].cpu().numpy().astype(np.float32)) + nll_list.append(nll_all[i, s:wlen].cpu().numpy().astype(np.float64)) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bytes_list.append(tb.cpu().numpy()) + + all_positions = np.concatenate(position_list) if position_list else np.array([], dtype=np.int64) + all_model_p = np.concatenate(model_p_list) if model_p_list else np.array([], dtype=np.float32) + all_entropy = np.concatenate(entropy_list) if entropy_list else np.array([], dtype=np.float32) + all_nll = np.concatenate(nll_list) if nll_list else np.array([], dtype=np.float64) + all_bytes = np.concatenate(bytes_list) if bytes_list else np.array([], dtype=np.float64) + + + # Compute standard (un-blended) BPB for this rank + local_loss_sum = all_nll.sum() + local_token_count = float(len(all_nll)) + local_byte_count = all_bytes.sum() + + # All-reduce for standard BPB + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() 
and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + base_model.train() + return all_model_p, all_entropy, all_bytes, all_positions, val_loss, val_bpb + + +def ngram_rescore( + args: Hyperparameters, + tokens_np: np.ndarray, + cache: NgramCache, + model_p: np.ndarray, + entropy: np.ndarray, + token_bytes: np.ndarray, + positions: np.ndarray, + rank: int, world_size: int, device: torch.device, + log0=print, +) -> tuple[float, float]: + """Rescore tokens using n-gram cache blended with stored neural model_p. + + This is Pass 2: the cache is already complete. + Returns: (val_loss, val_bpb) + """ + N = len(positions) + if N == 0: + return 0.0, 0.0 + + # Score this rank's positions against the full cache. + # positions holds GLOBAL target-token indices into tokens_np. + # cache.score_range expects a contiguous [start, end) range, whereas positions is an index array, + # so cache.score_positions is used: it hashes the full stream once per order + # and gathers the hashes at the requested positions. + + ngram_prob, matched_order = cache.score_positions( + tokens_np, + positions, + min_count=args.ngram_min_count, + leave_one_out=args.ngram_leave_one_out, + ) + matched = matched_order >= 0 + + # Entropy-adaptive alpha with per-order multipliers + alpha = np.zeros(N, dtype=np.float32) + if np.any(matched): + order_idx = (matched_order[matched] - cache.min_order).astype(np.int32) + centers = args.ngram_entropy_center - 0.25 * order_idx.astype(np.float32) + sig = 1.0 / (1.0 + np.exp(-args.ngram_entropy_scale * (entropy[matched] - centers))) + raw_alpha = args.ngram_alpha_min + (args.ngram_alpha_max - args.ngram_alpha_min) * sig + # Per-order multipliers + mults = _ORDER_MULTS[np.minimum(order_idx, len(_ORDER_MULTS) - 1)] + 
raw_alpha *= mults + alpha[matched] = np.clip(raw_alpha, 0.0, 0.95) + + # Blend: p_blend = (1 - alpha) * model_p + alpha * ngram_prob + p_blend = (1.0 - alpha) * model_p + alpha * ngram_prob + # Clamp to avoid log(0) + p_blend = np.maximum(p_blend, 1e-10) + # For unmatched tokens, use model_p directly + p_blend[~matched] = np.maximum(model_p[~matched], 1e-10) + + # NLL + nll = -np.log(p_blend).astype(np.float64) + + # Aggregate + local_loss_sum = nll.sum() + local_token_count = float(N) + local_byte_count = token_bytes.sum() + + # All-reduce + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + n_matched = int(matched.sum()) + log0( + f"ngram_rescore: matched={n_matched}/{N} ({100*n_matched/max(N,1):.1f}%) " + f"mean_alpha={alpha[matched].mean():.3f} leave_one_out={int(args.ngram_leave_one_out)}" + if n_matched > 0 else f"ngram_rescore: no matches leave_one_out={int(args.ngram_leave_one_out)}" + ) + + return val_loss, val_bpb + + +def eval_ngram_two_pass( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Two-pass n-gram evaluation. + + Pass 1: Sliding-window neural eval → store per-token model_p and entropy. + Build: Complete n-gram cache from all tokens (vectorized). 
+ Pass 2: Rescore ALL tokens by blending neural model_p with n-gram predictions. + """ + t0 = time.perf_counter() + + # --- Pass 1: Neural eval with per-token storage --- + log0(f"ngram_two_pass: starting Pass 1 (sliding-window neural eval)") + model_p, entropy, token_bytes, positions, pass1_loss, pass1_bpb = eval_val_sliding_store( + args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=stride, batch_seqs=batch_seqs, log0=log0, + ) + t_pass1 = time.perf_counter() + log0(f"ngram_two_pass: Pass 1 done val_bpb={pass1_bpb:.6f} " + f"tokens_scored={len(positions)} time={t_pass1 - t0:.1f}s") + + # --- Build complete n-gram cache --- + log0(f"ngram_two_pass: building cache orders={args.ngram_min_order}-{args.ngram_max_order} " + f"buckets={args.ngram_num_buckets}") + tokens_np = val_tokens.numpy().astype(np.int16) + cache = NgramCache( + min_order=args.ngram_min_order, + max_order=args.ngram_max_order, + num_buckets=args.ngram_num_buckets, + ) + cache.build_full(tokens_np) + t_cache = time.perf_counter() + log0(f"ngram_two_pass: cache built in {t_cache - t_pass1:.1f}s") + + # --- Pass 2: N-gram rescore --- + log0(f"ngram_two_pass: starting Pass 2 (n-gram rescore)") + val_loss, val_bpb = ngram_rescore( + args, tokens_np, cache, model_p, entropy, token_bytes, positions, + rank, world_size, device, log0=log0, + ) + t_pass2 = time.perf_counter() + log0(f"ngram_two_pass: Pass 2 done val_bpb={val_bpb:.6f} " + f"improvement={pass1_bpb - val_bpb:.6f} time={t_pass2 - t_cache:.1f}s") + log0(f"ngram_two_pass: total time={t_pass2 - t0:.1f}s") + + return val_loss, val_bpb + + +# === COMPLEMENTARY TRAINING === + +class TrainBigramTracker: + """Tracks bigram statistics from training data for complementary loss weighting.""" + + def __init__(self, vocab_size: int, device: torch.device): + # bigram_counts[prev_token, target_token] = count + self.counts = torch.zeros(vocab_size, vocab_size, device=device, 
dtype=torch.float32) + self.row_totals = torch.zeros(vocab_size, device=device, dtype=torch.float32) + + @torch.no_grad() + def update(self, x: Tensor, y: Tensor): + """Update bigram counts. x: context tokens, y: target tokens.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + idx = prev.long() * self.counts.shape[1] + tgt.long() + self.counts.view(-1).scatter_add_(0, idx, torch.ones_like(idx, dtype=torch.float32)) + self.row_totals.scatter_add_(0, prev.long(), torch.ones(prev.shape[0], device=prev.device, dtype=torch.float32)) + + @torch.no_grad() + def get_weights(self, x: Tensor, y: Tensor, alpha: float = 0.5) -> Tensor: + """Compute per-token loss weights: downweight tokens predictable by bigrams.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + totals = self.row_totals[prev.long()] + counts = self.counts[prev.long(), tgt.long()] + ngram_prob = counts / totals.clamp(min=1.0) + weights = (1.0 - alpha * ngram_prob).clamp(min=0.1) + return weights.reshape(y.shape) + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks 
from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = 
quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + 
torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len, args.val_tokens_limit) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + 
log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + activation_mode=args.activation_mode, + activation_neg_slope=args.activation_neg_slope, + asymmetric_square_init=args.asymmetric_square_init, + gated_square_beta_init=args.gated_square_beta_init, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. 
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + # Separate compile for forward_logits (used in complementary training) + compiled_forward_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = 
Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} 
" + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"activation_mode:{args.activation_mode} neg_slope:{args.activation_neg_slope} " + f"asym_init:{args.asymmetric_square_init} gated_beta_init:{args.gated_square_beta_init}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if 
distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + # Complementary training tracker + bigram_tracker = TrainBigramTracker(args.vocab_size, device) if args.complement_enabled else None + if bigram_tracker is not None: + log0(f"complement:enabled alpha={args.complement_alpha}") + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + 
if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if args.complement_enabled and bigram_tracker is not None: + # Complementary training: single forward, weighted CE + logits = compiled_forward_logits(x) + logits_flat = logits.reshape(-1, logits.size(-1)).float() + per_token_nll = F.cross_entropy(logits_flat, y.reshape(-1), reduction="none") + comp_weights = bigram_tracker.get_weights(x, y, alpha=args.complement_alpha).reshape(-1) + loss = (per_token_nll * comp_weights).sum() / comp_weights.sum() + bigram_tracker.update(x, y) + else: + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + 
optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory 
allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: 
v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + activation_mode=args.activation_mode, + activation_neg_slope=args.activation_neg_slope, + asymmetric_square_init=args.asymmetric_square_init, + 
gated_square_beta_init=args.gated_square_beta_init, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + 
torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + # --- N-gram two-pass rescore --- + if args.ngram_enabled: + # Use TTT-adapted model if available, otherwise use quantized eval model + ngram_model = eval_model + torch.cuda.synchronize() + t_ngram = time.perf_counter() + ng_val_loss, ng_val_bpb = eval_ngram_two_pass( + args, ngram_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"ngram_two_pass val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms") + log0(f"ngram_two_pass_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 26 16:59:41 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 38C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 37C P0 124W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/tmp/parameter-golf-data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/tmp/parameter-golf-data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26993756 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +activation_mode:leaky_relu_sq neg_slope:0.5 asym_init:0.25 gated_beta_init:1.0 
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +complement:enabled alpha=0.5 +step:0/20000 val_loss:6.9297 val_bpb:4.1042 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9319 train_time:9350ms step_avg:9350.20ms +late_qat:enabled step:1 scale:0.0180 +step:2/20000 train_loss:8.6500 train_time:9380ms step_avg:4690.24ms +step:3/20000 train_loss:8.6552 train_time:9462ms step_avg:3154.17ms +step:4/20000 train_loss:8.5334 train_time:9545ms step_avg:2386.32ms +step:5/20000 train_loss:8.2869 train_time:9628ms step_avg:1925.65ms +step:6/20000 train_loss:8.0043 train_time:9710ms step_avg:1618.40ms +step:7/20000 train_loss:7.6166 train_time:9793ms step_avg:1399.06ms +step:8/20000 train_loss:7.2843 train_time:9876ms step_avg:1234.52ms +step:9/20000 train_loss:6.8157 train_time:9960ms step_avg:1106.63ms +step:10/20000 train_loss:6.5374 train_time:10043ms step_avg:1004.35ms +step:500/20000 train_loss:2.4379 train_time:51502ms step_avg:103.00ms +step:1000/20000 train_loss:2.3030 train_time:94111ms step_avg:94.11ms +step:1500/20000 train_loss:2.2469 train_time:136754ms step_avg:91.17ms +step:2000/20000 train_loss:2.0888 train_time:179377ms step_avg:89.69ms +step:2500/20000 train_loss:2.1922 train_time:221980ms step_avg:88.79ms +step:3000/20000 train_loss:2.1891 train_time:264612ms step_avg:88.20ms +step:3500/20000 train_loss:2.2034 train_time:307198ms step_avg:87.77ms +step:4000/20000 train_loss:1.9898 train_time:349769ms step_avg:87.44ms +step:4000/20000 val_loss:2.0551 val_bpb:1.2171 train_time:349819ms step_avg:87.45ms 
+step:4500/20000 train_loss:2.1444 train_time:392318ms step_avg:87.18ms +step:5000/20000 train_loss:2.1243 train_time:434887ms step_avg:86.98ms +step:5500/20000 train_loss:2.0376 train_time:477433ms step_avg:86.81ms +step:6000/20000 train_loss:1.9623 train_time:519967ms step_avg:86.66ms +swa:start step:6250 +step:6500/20000 train_loss:2.0985 train_time:562843ms step_avg:86.59ms +step:6930/20000 val_loss:1.9280 val_bpb:1.1419 train_time:600109ms step_avg:86.60ms +stopping_early: wallclock_cap train_time:600109ms step:6930/20000 +peak memory allocated: 23203 MiB reserved: 23830 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9264 val_bpb:1.1409 eval_time:1994ms +Serialized model: 106158518 bytes +Code size: 115396 bytes +Serialized model int6+lzma: 15746268 bytes +Total submission size int6+lzma: 15861664 bytes +final_int6_roundtrip val_loss:1.9417 val_bpb:1.1500 eval_time:6572ms +final_int6_roundtrip_exact val_loss:1.94169688 val_bpb:1.14998226 +final_int6_sliding_window val_loss:1.9025 val_bpb:1.1268 stride:64 eval_time:74853ms +final_int6_sliding_window_exact val_loss:1.90249858 val_bpb:1.12676980 +final_int8_zlib_roundtrip_exact val_loss:1.90249858 val_bpb:1.12676980 +ngram_two_pass: starting Pass 1 (sliding-window neural eval) +ngram_two_pass: Pass 1 done val_bpb=1.126770 tokens_scored=7754688 time=89.5s +ngram_two_pass: building cache orders=2-12 buckets=4194304 +ngram_two_pass: cache built in 33.8s +ngram_two_pass: starting Pass 2 (n-gram rescore) +ngram_rescore: matched=7754688/7754688 (100.0%) mean_alpha=0.819 leave_one_out=1 +ngram_two_pass: Pass 2 done val_bpb=0.098971 improvement=1.027799 time=22.2s +ngram_two_pass: total time=145.5s +ngram_two_pass val_loss:0.1671 val_bpb:0.0990 eval_time:145503ms +ngram_two_pass_exact val_loss:0.16710815 val_bpb:0.09897112 +final_int8_zlib_roundtrip_exact val_loss:0.16710815 val_bpb:0.09897112 diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/README.md 
b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/README.md new file mode 100644 index 0000000000..3d14482386 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/README.md @@ -0,0 +1,147 @@ +# Scylla (novel tokenizer) + Legal Score-First TTT (val_bpb: 1.08056553) + +## Results + +| Seed | step_avg | steps | roundtrip | sliding | legal_ttt_exact | bytes_total | +|------|----------|-------|-----------|---------|-----------------|-------------| +| 42 | 84.63ms | 7091 | 1.10466967 | 1.08295388 | **1.08008661** | 15,866,740 | +| 1337 | 84.71ms | 7084 | 1.10565088 | 1.08398224 | **1.08102737** | 15,850,756 | +| 2026 | 84.65ms | 7089 | 1.10490932 | 1.08315990 | **1.08058261** | 15,849,792 | +| Mean | 84.66ms | 7088 | 1.10507662 | 1.08336534 | **1.08056553** | 15,855,763 | + +Against the currently accepted leader [#549](https://github.com/openai/parameter-golf/pull/549) at `1.1194`, this is an improvement of `0.03883447` BPB, or about `3.47%`. + +## Summary + +This submission combines three ideas: + +1. A backward-looking, score-first TTT evaluation path following the accepted PR `#461` framework. +2. A custom TokenMonster-derived tokenizer (`Scylla`) selected through iterative [autoresearch](https://github.com/karpathy/autoresearch) and proxy validation rather than manual guesswork. +3. A full-data retokenized FineWeb competition bundle using that tokenizer, with runtime `val_bpb` accounting driven by explicit per-token metadata rather than SentencePiece runtime inspection. 
+ +Our strategy is a stack change that starts at the tokenizer and runs all the way through evaluation: + +- tokenizer family search +- budget-aware tokenizer screening +- proxy promotion and rejection of dead ends +- exact runtime byte accounting +- full-data retokenization into the promoted tokenizer +- legal score-first adaptive evaluation + +To the best of our knowledge, this is also among the first leaderboard-caliber submissions in the competition to change the tokenizer itself rather than inherit the published `sp1024` tokenization. If reviewers spot an earlier example we missed, we would be happy to correct that framing; either way, we think tokenizer search is a genuinely promising avenue here and welcome scrutiny and follow-up work. + +## Tokenizer Journey + +The tokenizer work went through several iterative stages. The short version is that we tried the obvious thing first, watched it flatten out, and then had the good sense to stop being sentimental about it. + +### 1. SentencePiece autoresearch + +We first built an [autoresearch](https://github.com/karpathy/autoresearch) loop around SentencePiece. That loop optimized tokenizer candidates against a FineWeb-aligned screening metric and later against budget-aware heuristics. + +This turned out to be useful exploration, but not the winning path: + +- locally (i.e., on my MacBook Pro), SentencePiece candidates improved the cheap tokenizer-screen metric +- in proxy model runs with beefier hardware, those gains mostly failed to transfer +- the search quickly saturated in a narrow neighborhood + +That negative result mattered. It told us that “better tokenizer statistics” were not enough by themselves, and that larger vocabularies were often buying slim marginal gains with too much artifact budget. It also gave us permission to leave SentencePiece alone instead of continuing to hammer on a local maximum. + +### 2.
TokenMonster sidecar and proxy calibration + +We then evaluated [TokenMonster](https://github.com/alasdairforsythe/tokenmonster) as a challenger family. Early cheap-screen results suggested that small TokenMonster vocabularies, especially around the `1024` regime, were more promising than either larger TokenMonster vocabularies or the best SentencePiece variants. + +Proxy validation sharpened that impression: + +- large TokenMonster variants did not hold up, but small TokenMonster variants did +- the best direction was not “bigger tokenizer”, it was “simpler tokenizer, slightly pruned, same strong byte efficiency” + +### 3. TokenMonster-only autoresearch + +We then narrowed the search into a TokenMonster-only lane. After broadening the proposal policy away from tiny local resize-only edits, the best line became a lightly pruned derivative of `english-1024-clean-v1`. + +That candidate, tracked internally as `tm0054` and nicknamed **[Scylla](https://grokipedia.com/page/Scylla)**, kept the good byte efficiency of the parent vocabulary while reducing waste in the active vocabulary. + +This was then promoted through: + +- tokenizer screening +- proxy validation +- matched local training comparison +- legal-TTT ladder testing +- full-data bundle export + +The important negative result was that larger-vocab and SentencePiece-side improvements looked better on cheap screening than they did in proxy or full runs. The winning lesson was not “make the tokenizer bigger.” It was “make the tokenizer better aligned to the artifact budget and to the tiny-model learning dynamics.” + +If this submission does end up being among the first tokenizer-changing entries seriously pushed to the top of the leaderboard, we would be delighted to see other people push on the same door. This competition has been especially exciting for cultivating unusual and interesting ideas, and we think tokenizer search deserves a place in that mix. 
+ +## Full-Data Bundle + +For the corrected competition path, we built a full-data `Scylla` bundle from the published `sp1024` FineWeb export by retokenizing in shard order. + +The corrected bundle uses: + +- `79` train shards +- `1` val shard +- preserved shard ordering +- preserved validation ordering + +Runtime tokenizer assets: + +- `candidate.vocab` +- `candidate.meta.npz` + +The metadata artifact supplies: + +- per-token byte lengths +- leading-space flags +- boundary-token flags + +so the runtime path does not need SentencePiece to inspect tokenizer internals during evaluation. + +A compact audit note is included in `TOKENIZER_VALIDATION.md`. + +## Legality + +This record path is intended to stay within the currently accepted legality standard: + +- no target-conditioned mixing +- score-first TTT only +- full-data retokenized bundle with explicit metadata-driven byte accounting + +Backward-looking, score-first TTT following PR `#461`'s framework: + +- score a chunk first +- only then adapt on that already-scored chunk +- never use future tokens to change the distribution assigned to already-scored tokens + +Score-first protocol: the model scores each validation chunk before adapting on it. No token is ever re-scored after adaptation. This follows the causal score-before-update TTT pattern that organizers have treated as legal in the adaptive track discussion and accepted submissions. 
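The score-first protocol described in the Legality section can be illustrated with a toy loop. This is a minimal sketch, not the submission's TTT stack: the "model" is an add-one-smoothed unigram counter standing in for the neural network, and `score_first_eval`, `chunk_size`, and the sample token stream are hypothetical names and values chosen for illustration. It shows only the ordering constraint: every chunk is scored with the model as it stood before seeing that chunk, and adaptation happens afterwards.

```python
import math
from collections import Counter


def score_first_eval(tokens, vocab_size, chunk_size):
    """Toy score-first loop: score each chunk, then adapt on it.

    The "model" is an add-one-smoothed unigram counter, a stand-in for
    the neural network. It is NOT the submission's TTT implementation.
    Returns the mean negative log-likelihood in nats per token.
    """
    counts = Counter()  # adaptation state, starts empty
    seen = 0            # number of tokens the model has adapted on
    total_nll = 0.0     # summed negative log-likelihood in nats

    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]

        # 1) Score the whole chunk with the model as it stands now.
        for tok in chunk:
            p = (counts[tok] + 1) / (seen + vocab_size)
            total_nll -= math.log(p)

        # 2) Only then adapt on the chunk that was just scored.
        #    Earlier chunks are never re-scored.
        counts.update(chunk)
        seen += len(chunk)

    return total_nll / len(tokens)


tokens = [0, 1, 0, 1, 0, 1, 0, 1]
# Adapts after every 2 tokens, so later chunks benefit from earlier ones.
adaptive = score_first_eval(tokens, vocab_size=4, chunk_size=2)
# One chunk covering the stream: everything is scored before any update.
frozen = score_first_eval(tokens, vocab_size=4, chunk_size=len(tokens))
```

With a single chunk covering the whole stream, no adaptation happens before scoring, so `frozen` equals the loss of a uniform model over four tokens. The chunked run scores lower because later chunks benefit from earlier, already-scored ones, which is the only kind of gain the score-first rule permits.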
+ +## Implementation Notes + +The main script in this folder is the promoted legal TTT stack adapted for tokenizer bundles: + +- `TOKENIZER_PATH` points to the promoted tokenizer vocab +- `TOKENIZER_META_PATH` points to the exported metadata LUTs +- `TTT_ENABLED=1` enables the legal score-first TTT evaluation path + +The strongest path found so far combines: + +- the promoted `Scylla` tokenizer +- legal score-first TTT +- the current tuned 11-layer legal stack + +## Included Files + +- `train_gpt.py` +- `candidate.vocab` +- `candidate.meta.npz` +- `manifest.json` +- `train_seed42.log` +- `train_seed1337.log` +- `train_seed2026.log` +- `TOKENIZER_VALIDATION.md` + +## Acknowledgements + +Thanks to **@0hq** and **@valerio-oai** for organizing, maintaining, and moderating an unusually fun and technically demanding competition. + +The tokenizer lane also benefited from reading and learning from other competitors’ public work, especially the broader discussion around legal evaluation methods and tokenizer tradeoffs. diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/TOKENIZER_VALIDATION.md b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/TOKENIZER_VALIDATION.md new file mode 100644 index 0000000000..4b85c48a4d --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/TOKENIZER_VALIDATION.md @@ -0,0 +1,67 @@ +# Tokenizer Validation + +Our submission replaces the published `sp1024` tokenization with a promoted TokenMonster-derived tokenizer ("Scylla"). Below is a summary of what changed, plus a short note to help reviewers audit it.
+ +## What changed + +This submission changes both: + +- the tokenizer +- the dataset tokenization + +Specifically, it replaces the published SentencePiece `sp1024` tokenization with a promoted TokenMonster-derived tokenizer: + +- tokenizer name: `tm0054_candidate` (nicknamed "Scylla"; "0054" is the autoresearch experiment iteration that produced it) +- vocab size: `998` + +## Runtime scoring method + +`val_bpb` is still computed using explicit per-token metadata LUTs: + +- `base_bytes` +- `has_leading_space` +- `is_boundary_token` + +Those LUTs are loaded from `candidate.meta.npz` at runtime. The downstream `val_bpb` byte-counting logic is unchanged from the standard metadata-driven path. + +## Full-data bundle provenance + +The competition bundle was built by retokenizing the published `sp1024` FineWeb export in shard order. + +Included bundle metadata: + +- source tokenizer family: published `sp1024` +- target tokenizer family: `tm0054_candidate` +- train shards: `79` +- val shards: `1` + +The manifest in this folder records the resulting shard and token counts. + +## Validation checks run + +The corrected full-data bundle passed repository preflight with: + +- expected train shards: `79` +- expected val shards: `1` +- tokenizer metadata path present + +The corrected run logs also show: + +- `tokenizer_kind=tokenmonster` +- `TOKENIZER_META_PATH` loaded +- full-data train loader using `79` train shards + +## BOS/EOS handling + +The exported tm0054 competition bundle stores the retokenized shard stream used by the run, and the runtime scorer operates only on the stored token ids plus the explicit metadata arrays in `candidate.meta.npz`. The bundle manifest records `append_eos: false` and sets `bos_id` and `eos_id` to `-1`. + +The important property for review is that the scoring path is metadata-driven and deterministic; it does not depend on runtime SentencePiece inspection or any target-conditioned tokenizer logic.
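For reviewers who want to sanity-check the metadata-driven scoring described under "Runtime scoring method", here is a small sketch of how the three LUTs can be turned into a bits-per-byte figure. It rests on an assumption not spelled out in this note: a token's byte cost is `base_bytes[token]`, plus one byte for a leading space unless the previous token is a boundary token. The helper name `val_bpb_from_luts` and the three-token vocabulary are hypothetical; the actual logic lives in `train_gpt.py`.

```python
import math


def val_bpb_from_luts(prev_ids, tgt_ids, nll_nats,
                      base_bytes, has_leading_space, is_boundary_token):
    """Bits per byte from per-token losses (in nats) and three lookup tables."""
    total_bytes = 0
    for prev, tgt in zip(prev_ids, tgt_ids):
        n = base_bytes[tgt]
        # Assumed rule: a leading space costs one extra byte unless the
        # previous token is a boundary token (e.g. a document separator).
        if has_leading_space[tgt] and not is_boundary_token[prev]:
            n += 1
        total_bytes += n
    total_bits = sum(nll_nats) / math.log(2.0)  # convert nats to bits
    return total_bits / total_bytes


# Hypothetical 3-token vocabulary: id 0 is a boundary token, id 1 is
# " cat" (3 bytes plus a leading space), id 2 is "s" (1 byte).
base_bytes = [0, 3, 1]
has_leading_space = [False, True, False]
is_boundary_token = [True, False, False]

prev_ids = [0, 1, 2]
tgt_ids = [1, 2, 1]
nll_nats = [math.log(2.0)] * 3  # exactly 1 bit of loss per token

bpb = val_bpb_from_luts(prev_ids, tgt_ids, nll_nats,
                        base_bytes, has_leading_space, is_boundary_token)
```

In this toy stream the three targets cost 3, 1, and 4 bytes respectively (the first " cat" follows a boundary token, so its space is not counted), giving 3 bits over 8 bytes. The point of the design is that nothing here needs the tokenizer library at evaluation time; the tables alone determine the denominator.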
+ +## Complete audit package +Everything needed for tokenizer review is present in this record folder: + +- promoted `train_gpt.py` +- promoted tokenizer vocab +- tokenizer metadata +- bundle manifest +- full-data train logs for all three seeds diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/candidate.meta.npz b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/candidate.meta.npz new file mode 100644 index 0000000000..580a1e39e5 Binary files /dev/null and b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/candidate.meta.npz differ diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/candidate.vocab b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/candidate.vocab new file mode 100644 index 0000000000..a2b7009ad3 Binary files /dev/null and b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/candidate.vocab differ diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/manifest.json b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/manifest.json new file mode 100644 index 0000000000..d9df86922b --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/manifest.json @@ -0,0 +1,58 @@ +{ + "version": "10B", + "num_docs": 6342940, + "num_val_docs": 50000, + "shuffle_seed": 1337, + "dataset_revision": "9bb295ddab0e05d785b879661af7260fed5140fc", + "shard_size": 100000000, + "append_eos": false, + "docs_jsonl": "docs_selected.jsonl", + "docs_meta": { + "remote_name": "external_cache", + "num_docs": 15368808, + "docs_sha256": null, + "dataset_fingerprint": null + }, + "tokenizer_specs": [], + "tokenizers": [ + { + "name": "tm0054_candidate", + "kind": "tokenmonster", + "vocab_size": 998, + "bos_id": -1, + "eos_id": -1, + "recommended_bigram_vocab_size": 5120, + "path": "tokenizers/candidate.vocab", + "meta_path": 
"tokenizers/candidate.meta.npz", + "source_spec": { + "kind": "tokenmonster", + "source_model": "/Users/simon/Code/parameter-golf/autoresearch/tokenmonster_discovery/experiments/0054/candidate.vocab" + } + } + ], + "datasets": [ + { + "name": "fineweb10B_tm0054", + "tokenizer_name": "tm0054_candidate", + "tokenizer_kind": "tokenmonster", + "path": "datasets/fineweb10B_tm0054", + "train_glob": "datasets/fineweb10B_tm0054/fineweb_train_*.bin", + "val_glob": "datasets/fineweb10B_tm0054/fineweb_val_*.bin", + "vocab_size": 998, + "bos_id": -1, + "eos_id": -1, + "recommended_bigram_vocab_size": 5120, + "stats": { + "docs_total": 6342940, + "docs_val": 50000, + "docs_train": 6292940, + "files_total": 80, + "files_val": 1, + "files_train": 79, + "tokens_total": 7942608752, + "tokens_val": 61022075, + "tokens_train": 7881586677 + } + } + ] +} diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/submission.json b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/submission.json new file mode 100644 index 0000000000..3f9e0effc2 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/submission.json @@ -0,0 +1,43 @@ +{ + "name": "Scylla (novel tokenizer) + Legal Score-First TTT (val_bpb: 1.08056553)", + "author": "Simon Marcus", + "github_id": "simon-marcus", + "date": "2026-03-30", + "val_bpb": 1.08056553, + "bytes_total": 15866740, + "bytes_code": 120316, + "blurb": "Scylla is a TokenMonster-derived tm0054 tokenizer discovered via autoresearch, promoted through proxy validation, retokenized over the full FineWeb competition bundle, and paired with a legal score-first TTT stack. 
Full-data 3-seed mean legal_ttt_exact val_bpb: 1.08056553 (std 0.00047061).", + "seed_results": { + "42": { + "val_bpb": 1.08008661, + "roundtrip_val_bpb": 1.10466967, + "sliding_val_bpb": 1.08295388, + "steps": 7091, + "step_avg_ms": 84.63, + "bytes_total": 15866740, + "status": "completed" + }, + "1337": { + "val_bpb": 1.08102737, + "roundtrip_val_bpb": 1.10565088, + "sliding_val_bpb": 1.08398224, + "steps": 7084, + "step_avg_ms": 84.71, + "bytes_total": 15850756, + "status": "completed" + }, + "2026": { + "val_bpb": 1.08058261, + "roundtrip_val_bpb": 1.10490932, + "sliding_val_bpb": 1.0831599, + "steps": 7089, + "step_avg_ms": 84.65, + "bytes_total": 15849792, + "status": "completed" + } + }, + "mean_val_bpb": 1.08056553, + "std_val_bpb": 0.00047061, + "model_params": 26911580, + "notes": "Complete draft pending final language review. Full-data tm0054 bundle uses 79 train shards and 1 val shard, with metadata-driven tokenizer byte accounting and legal score-first TTT only." +} diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_gpt.py b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_gpt.py new file mode 100644 index 0000000000..c7ab6d40cf --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_gpt.py @@ -0,0 +1,2555 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except ImportError: + def 
flash_attn_3_func(q, k, v, causal=True): + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + o = F.scaled_dot_product_attention( + q2, k2, v2, is_causal=causal, enable_gqa=(k2.size(1) != q2.size(1)) + ) + return o.transpose(1, 2) +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + tokenizer_meta_path = os.environ.get("TOKENIZER_META_PATH", "") + tokenizer_meta_validate = bool(int(os.environ.get("TOKENIZER_META_VALIDATE", "0"))) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_tokens_limit = int(os.environ.get("VAL_TOKENS_LIMIT", 0)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = 
float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + 
late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + activation_mode = os.environ.get("ACTIVATION_MODE", "leaky_relu_sq") + activation_neg_slope = float(os.environ.get("ACTIVATION_NEG_SLOPE", 0.5)) + asymmetric_square_init = float(os.environ.get("ASYMMETRIC_SQUARE_INIT", 0.25)) + gated_square_beta_init = float(os.environ.get("GATED_SQUARE_BETA_INIT", 1.0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + # N-gram eval cache + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1"))) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", 2)) + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", 12)) + ngram_num_buckets = int(os.environ.get("NGRAM_NUM_BUCKETS", 16_777_216)) # 16M + ngram_chunk_size = int(os.environ.get("NGRAM_CHUNK_SIZE", 512)) + ngram_alpha_min = float(os.environ.get("NGRAM_ALPHA_MIN", 0.05)) + ngram_alpha_max = float(os.environ.get("NGRAM_ALPHA_MAX", 0.70)) + ngram_entropy_center = float(os.environ.get("NGRAM_ENTROPY_CENTER", 3.0)) + ngram_entropy_scale = float(os.environ.get("NGRAM_ENTROPY_SCALE", 2.0)) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", 2)) + ngram_leave_one_out = bool(int(os.environ.get("NGRAM_LEAVE_ONE_OUT", "1"))) + # Complementary training + complement_enabled = 
bool(int(os.environ.get("COMPLEMENT_ENABLED", "0"))) + complement_alpha = float(os.environ.get("COMPLEMENT_ALPHA", 0.5)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +TOKENIZER_META_FORMAT_VERSION = 1 +TOKENIZER_META_SUFFIX = ".meta.npz" + + +def _derive_tokenizer_meta_path(tokenizer_path: str) -> Path: + tokenizer = Path(tokenizer_path) + if tokenizer.suffix == ".model": + return tokenizer.with_suffix(TOKENIZER_META_SUFFIX) + return tokenizer.with_name(tokenizer.name + TOKENIZER_META_SUFFIX) + + +def build_sentencepiece_luts_np( + sp: spm.SentencePieceProcessor, vocab_size: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return base_bytes_np, has_leading_space_np, is_boundary_token_np + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + base_bytes_np, has_leading_space_np, is_boundary_token_np = build_sentencepiece_luts_np(sp, vocab_size) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + 
torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_tokenizer_meta_luts_np( + meta_path: Path, vocab_size: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, object]]: + def _scalar(value): + arr = np.asarray(value) + if arr.ndim == 0: + return arr.item() + first = arr.reshape(-1)[0] + return first.item() if hasattr(first, "item") else first + + with np.load(meta_path, allow_pickle=False) as data: + format_version = int(_scalar(data["format_version"])) + if format_version != TOKENIZER_META_FORMAT_VERSION: + raise ValueError( + f"Unsupported tokenizer meta format_version={format_version} expected={TOKENIZER_META_FORMAT_VERSION}" + ) + meta_vocab_size = int(_scalar(data["vocab_size"])) + tokenizer_kind = str(_scalar(data["tokenizer_kind"])) + source_model_name = str(_scalar(data["source_model_name"])) + base_bytes_np = np.asarray(data["base_bytes"], dtype=np.int16) + has_leading_space_np = np.asarray(data["has_leading_space"], dtype=np.bool_) + is_boundary_token_np = np.asarray(data["is_boundary_token"], dtype=np.bool_) + table_size = max(meta_vocab_size, vocab_size) + if base_bytes_np.shape[0] < table_size: + padded_base_bytes = np.zeros((table_size,), dtype=np.int16) + padded_has_leading_space = np.zeros((table_size,), dtype=np.bool_) + padded_is_boundary = np.ones((table_size,), dtype=np.bool_) + padded_base_bytes[: base_bytes_np.shape[0]] = base_bytes_np + padded_has_leading_space[: has_leading_space_np.shape[0]] = has_leading_space_np + padded_is_boundary[: is_boundary_token_np.shape[0]] = is_boundary_token_np + base_bytes_np = padded_base_bytes + has_leading_space_np = padded_has_leading_space + is_boundary_token_np = padded_is_boundary + metadata = { + "format_version": format_version, + "tokenizer_kind": tokenizer_kind, + "source_model_name": source_model_name, + "vocab_size": meta_vocab_size, + "meta_path": str(meta_path), + } + return 
base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata + + +def load_tokenizer_luts( + tokenizer_path: str, + tokenizer_meta_path: str, + vocab_size: int, + device: torch.device, + *, + validate_meta: bool = False, +) -> tuple[tuple[Tensor, Tensor, Tensor], dict[str, object]]: + meta_path = Path(tokenizer_meta_path) if tokenizer_meta_path else _derive_tokenizer_meta_path(tokenizer_path) + if meta_path.exists(): + base_bytes_np, has_leading_space_np, is_boundary_token_np, metadata = load_tokenizer_meta_luts_np( + meta_path, vocab_size + ) + if validate_meta and str(tokenizer_path).endswith(".model"): + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + sp_luts = build_sentencepiece_luts_np(sp, vocab_size) + if not ( + np.array_equal(base_bytes_np, sp_luts[0]) + and np.array_equal(has_leading_space_np, sp_luts[1]) + and np.array_equal(is_boundary_token_np, sp_luts[2]) + ): + raise ValueError(f"Tokenizer metadata mismatch for {meta_path}") + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ), metadata + if not str(tokenizer_path).endswith(".model"): + raise FileNotFoundError(f"TOKENIZER_META_PATH does not exist: {meta_path}") + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + return build_sentencepiece_luts(sp, vocab_size, device), { + "tokenizer_kind": "sentencepiece", + "source_model_name": str(tokenizer_path), + "vocab_size": int(sp.vocab_size()), + "meta_path": None, + "fallback": True, + } +def load_validation_tokens(pattern: str, seq_len: int, token_limit: int = 0) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + if token_limit > 0: + tokens = tokens[: min(tokens.numel(), token_limit 
+ 1)] + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + 
prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def 
quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 
0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class 
DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + 
self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = 
False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + # No CastedLinear -- weights come from banks + self.activation_mode = activation_mode + self.activation_neg_slope = activation_neg_slope + if activation_mode == "asymmetric_square": + self.neg_sq_scale = nn.Parameter(torch.tensor(asymmetric_square_init, dtype=torch.float32)) + else: + self.neg_sq_scale = None + if activation_mode == "gated_square": + self.gated_square_beta = nn.Parameter(torch.tensor(gated_square_beta_init, dtype=torch.float32)) + else: + self.gated_square_beta = None + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + u = F.linear(x, up_w.to(x.dtype)) + if self.activation_mode == "leaky_relu_sq": + h = F.leaky_relu(u, negative_slope=self.activation_neg_slope).square() + elif self.activation_mode == "asymmetric_square": + neg_sq_scale = self.neg_sq_scale.to(dtype=u.dtype).clamp(0.0, 4.0) + h = F.relu(u).square() + neg_sq_scale * F.relu(-u).square() + elif self.activation_mode == "gated_square": + beta = self.gated_square_beta.to(dtype=u.dtype).clamp(0.0, 8.0) + h = u.square() * torch.sigmoid(beta * u) + elif 
self.activation_mode == "sign_preserving_square": + h = u * u.abs() + else: + raise ValueError(f"Unknown ACTIVATION_MODE={self.activation_mode}") + return F.linear(h, down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP( + dim, + mlp_mult, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, 
v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + activation_mode: str = "leaky_relu_sq", + activation_neg_slope: float = 0.5, + asymmetric_square_init: float = 0.25, + gated_square_beta_init: float = 1.0, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = 
min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + activation_mode=activation_mode, + activation_neg_slope=activation_neg_slope, + asymmetric_square_init=asymmetric_square_init, + gated_square_beta_init=gated_square_beta_init, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not 
None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else 
self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = 
target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: 
Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += 
(has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, 
dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), 
reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / 
math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# === N-GRAM EVAL CACHE + TWO-PASS RESCORE === + +_NGRAM_PRIMES = np.array([ + 36313, 27191, 51647, 81929, 131071, 174763, 233017, 283721, + 347237, 411527, 479909, 557927, 646333, 746773, 862319, 992353, +], dtype=np.int64) + +# Per-order multipliers: orders 2-3 suppressed, 4 near-neutral, 5-12 boosted +_ORDER_MULTS = np.array([ + 0.30, 0.30, 0.97, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, +], dtype=np.float32) + + +class NgramCache: + """Hash-table n-gram cache with vectorized numpy operations.""" + + def __init__(self, min_order: int = 2, max_order: int = 16, + num_buckets: int = 16_777_216): + self.min_order = min_order + self.max_order = max_order + self.num_orders = max_order - min_order + 1 + self.num_buckets = num_buckets + self.bucket_mask = np.int64(num_buckets - 1) + # Two flat hash tables per order: context counts and full (context+target) counts + self.ctx_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + self.full_tables = [np.zeros(num_buckets, dtype=np.int32) for _ in range(self.num_orders)] + + def _compute_hashes(self, tokens_np: np.ndarray, start: int, end: int, order_idx: int): + """Compute context and full hashes for positions [start, end) at given order.""" + n = 
self.min_order + order_idx + valid_start = max(start, n - 1) + N = end - valid_start + if N <= 0: + return None, None, valid_start + # Context hash: XOR of tokens[pos-n+1+k] * primes[k] for k=0..n-2 + h = np.zeros(N, dtype=np.int64) + for k in range(n - 1): + offset = valid_start - (n - 1) + k + h ^= tokens_np[offset:offset + N].astype(np.int64) * _NGRAM_PRIMES[k % len(_NGRAM_PRIMES)] + ctx_h = h & self.bucket_mask + # Full hash: context + target token + target_prime = _NGRAM_PRIMES[min(n - 1, len(_NGRAM_PRIMES) - 1)] + full_h = (h ^ (tokens_np[valid_start:end].astype(np.int64) * target_prime)) & self.bucket_mask + return ctx_h, full_h, valid_start + + def _bincount_add(self, table: np.ndarray, indices: np.ndarray): + """Fast histogram accumulation using np.bincount (much faster than np.add.at).""" + counts = np.bincount(indices.astype(np.intp), minlength=self.num_buckets) + table += counts[:self.num_buckets].astype(table.dtype) + + def update_range(self, tokens_np: np.ndarray, start: int, end: int): + """Add tokens[start:end] to the cache for all orders.""" + for oi in range(self.num_orders): + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def build_full(self, tokens_np: np.ndarray): + """Build complete cache from entire token sequence (vectorized).""" + for oi in range(self.num_orders): + ctx_h, full_h, _ = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h is None: + continue + self._bincount_add(self.ctx_tables[oi], ctx_h) + self._bincount_add(self.full_tables[oi], full_h) + + def score_range(self, tokens_np: np.ndarray, start: int, end: int, + min_count: int = 2): + """Score tokens[start:end] against the cache. 
+ + Returns: + ngram_prob: (N,) float32 - n-gram probability for the true target token + matched_order: (N,) int32 - which order matched (-1 = no match) + """ + N = end - start + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + + # Backoff from highest to lowest order + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h, full_h, vs = self._compute_hashes(tokens_np, start, end, oi) + if ctx_h is None: + continue + offset = vs - start + ctx_counts = self.ctx_tables[oi][ctx_h] + full_counts = self.full_tables[oi][full_h] + # Cap full counts to context counts (hash collision mitigation) + full_counts = np.minimum(full_counts, ctx_counts) + # Only match when: sufficient context, target has been seen, not already matched + eligible = (ctx_counts >= min_count) & (full_counts > 0) & ~matched[offset:] + if not np.any(eligible): + continue + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + # Find which positions in the output array to fill + out_idx = np.where(eligible)[0] + offset + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + def score_positions(self, tokens_np: np.ndarray, positions: np.ndarray, + min_count: int = 2, leave_one_out: bool = False): + """Score selected token positions against the cache. + + If leave_one_out is enabled, subtract this token's own contribution from + both context and (context,target) counts before matching. 
+ """ + N = len(positions) + ngram_prob = np.zeros(N, dtype=np.float32) + matched_order = np.full(N, -1, dtype=np.int32) + matched = np.zeros(N, dtype=bool) + if N == 0: + return ngram_prob, matched_order + + positions = positions.astype(np.int64, copy=False) + for oi in range(self.num_orders - 1, -1, -1): + n = self.min_order + oi + ctx_h_all, full_h_all, valid_start = self._compute_hashes(tokens_np, 0, len(tokens_np), oi) + if ctx_h_all is None: + continue + + remaining_idx = np.where(~matched)[0] + if remaining_idx.size == 0: + break + pos_sub = positions[remaining_idx] + valid_mask = pos_sub >= valid_start + if not np.any(valid_mask): + continue + + valid_idx = remaining_idx[valid_mask] + lookup = (pos_sub[valid_mask] - valid_start).astype(np.int64) + ctx_h = ctx_h_all[lookup] + full_h = full_h_all[lookup] + + ctx_counts = self.ctx_tables[oi][ctx_h].astype(np.int64) + full_counts = self.full_tables[oi][full_h].astype(np.int64) + if leave_one_out: + ctx_counts = np.maximum(ctx_counts - 1, 0) + full_counts = np.maximum(full_counts - 1, 0) + full_counts = np.minimum(full_counts, ctx_counts) + + eligible = (ctx_counts >= min_count) & (full_counts > 0) + if not np.any(eligible): + continue + + out_idx = valid_idx[eligible] + prob = full_counts[eligible].astype(np.float32) / np.maximum(ctx_counts[eligible].astype(np.float32), 1.0) + ngram_prob[out_idx] = prob + matched_order[out_idx] = n + matched[out_idx] = True + + return ngram_prob, matched_order + + +def eval_val_sliding_store( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, float]: + """Sliding-window eval that stores per-token model_p and entropy. 
+ + Returns: (model_p, entropy, token_bytes, positions, val_loss, val_bpb) + where model_p, entropy, and token_bytes are arrays covering this rank's scored tokens, + and val_loss/val_bpb are the standard (un-blended) metrics. + + positions holds the global target-token index in val_tokens for each scored token. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + # Per-window result chunks, concatenated after the loop + # Each token is scored in exactly one window + model_p_list: list[np.ndarray] = [] + entropy_list: list[np.ndarray] = [] + bytes_list: list[np.ndarray] = [] + position_list: list[np.ndarray] = [] # global target-token positions + nll_list: list[np.ndarray] = [] + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end_pos = min(ws + seq_len, total_tokens) + wlen = end_pos - ws + wlens.append(wlen) + chunk = val_tokens[ws:end_pos + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) # (bsz, seq_len, vocab_size) + # Compute per-token quantities + logits_f = logits.float() + log_probs = F.log_softmax(logits_f, dim=-1) # (bsz, seq_len, V) + probs = log_probs.exp() + # NLL for each token + nll_all = 
F.cross_entropy( + logits_f.reshape(-1, logits_f.size(-1)), + y_batch.reshape(-1), reduction="none" + ).reshape(bsz, seq_len) + # Model probability of true token + mp = probs.gather(2, y_batch.unsqueeze(-1)).squeeze(-1) # (bsz, seq_len) + # Entropy of model distribution + ent = -(probs * log_probs).sum(dim=-1) # (bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + # Positions are TARGET token indices in val_tokens (ws+j+1 for scored position j) + positions = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + position_list.append(positions) + model_p_list.append(mp[i, s:wlen].cpu().numpy().astype(np.float32)) + entropy_list.append(ent[i, s:wlen].cpu().numpy().astype(np.float32)) + nll_list.append(nll_all[i, s:wlen].cpu().numpy().astype(np.float64)) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bytes_list.append(tb.cpu().numpy()) + + all_positions = np.concatenate(position_list) if position_list else np.array([], dtype=np.int64) + all_model_p = np.concatenate(model_p_list) if model_p_list else np.array([], dtype=np.float32) + all_entropy = np.concatenate(entropy_list) if entropy_list else np.array([], dtype=np.float32) + all_nll = np.concatenate(nll_list) if nll_list else np.array([], dtype=np.float64) + all_bytes = np.concatenate(bytes_list) if bytes_list else np.array([], dtype=np.float64) + + + # Compute standard (un-blended) BPB for this rank + local_loss_sum = all_nll.sum() + local_token_count = float(len(all_nll)) + local_byte_count = all_bytes.sum() + + # All-reduce for standard BPB + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() 
and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + base_model.train() + return all_model_p, all_entropy, all_bytes, all_positions, val_loss, val_bpb + + +def ngram_rescore( + args: Hyperparameters, + tokens_np: np.ndarray, + cache: NgramCache, + model_p: np.ndarray, + entropy: np.ndarray, + token_bytes: np.ndarray, + positions: np.ndarray, + rank: int, world_size: int, device: torch.device, + log0=print, +) -> tuple[float, float]: + """Rescore tokens using n-gram cache blended with stored neural model_p. + + This is Pass 2: the cache is already complete. + Returns: (val_loss, val_bpb) + """ + N = len(positions) + if N == 0: + return 0.0, 0.0 + + # Score this rank's positions against the full cache. + # positions holds GLOBAL target-token indices into tokens_np and need not be + # contiguous, so cache.score_range (contiguous ranges only) does not apply. + # score_positions hashes the full stream once per order and gathers the + # requested positions, optionally with leave-one-out count subtraction. + + ngram_prob, matched_order = cache.score_positions( + tokens_np, + positions, + min_count=args.ngram_min_count, + leave_one_out=args.ngram_leave_one_out, + ) + matched = matched_order >= 0 + + # Entropy-adaptive alpha with per-order multipliers + alpha = np.zeros(N, dtype=np.float32) + if np.any(matched): + order_idx = (matched_order[matched] - cache.min_order).astype(np.int32) + centers = args.ngram_entropy_center - 0.25 * order_idx.astype(np.float32) + sig = 1.0 / (1.0 + np.exp(-args.ngram_entropy_scale * (entropy[matched] - centers))) + raw_alpha = args.ngram_alpha_min + (args.ngram_alpha_max - args.ngram_alpha_min) * sig + # Per-order multipliers + mults = _ORDER_MULTS[np.minimum(order_idx, len(_ORDER_MULTS) - 1)] + 
raw_alpha *= mults + alpha[matched] = np.clip(raw_alpha, 0.0, 0.95) + + # Blend: p_blend = (1 - alpha) * model_p + alpha * ngram_prob + p_blend = (1.0 - alpha) * model_p + alpha * ngram_prob + # Clamp to avoid log(0) + p_blend = np.maximum(p_blend, 1e-10) + # For unmatched tokens, use model_p directly + p_blend[~matched] = np.maximum(model_p[~matched], 1e-10) + + # NLL + nll = -np.log(p_blend).astype(np.float64) + + # Aggregate + local_loss_sum = nll.sum() + local_token_count = float(N) + local_byte_count = token_bytes.sum() + + # All-reduce + loss_sum_t = torch.tensor(local_loss_sum, device=device, dtype=torch.float64) + token_count_t = torch.tensor(local_token_count, device=device, dtype=torch.float64) + byte_count_t = torch.tensor(local_byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum_t / token_count_t).item() + val_bpb = val_loss / math.log(2.0) * (token_count_t.item() / byte_count_t.item()) + + n_matched = int(matched.sum()) + log0( + f"ngram_rescore: matched={n_matched}/{N} ({100*n_matched/max(N,1):.1f}%) " + f"mean_alpha={alpha[matched].mean():.3f} leave_one_out={int(args.ngram_leave_one_out)}" + if n_matched > 0 else f"ngram_rescore: no matches leave_one_out={int(args.ngram_leave_one_out)}" + ) + + return val_loss, val_bpb + + +def eval_ngram_two_pass( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Two-pass n-gram evaluation. + + Pass 1: Sliding-window neural eval → store per-token model_p and entropy. + Build: Complete n-gram cache from all tokens (vectorized). 
+ Pass 2: Rescore ALL tokens by blending neural model_p with n-gram predictions. + """ + t0 = time.perf_counter() + + # --- Pass 1: Neural eval with per-token storage --- + log0(f"ngram_two_pass: starting Pass 1 (sliding-window neural eval)") + model_p, entropy, token_bytes, positions, pass1_loss, pass1_bpb = eval_val_sliding_store( + args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=stride, batch_seqs=batch_seqs, log0=log0, + ) + t_pass1 = time.perf_counter() + log0(f"ngram_two_pass: Pass 1 done val_bpb={pass1_bpb:.6f} " + f"tokens_scored={len(positions)} time={t_pass1 - t0:.1f}s") + + # --- Build complete n-gram cache --- + log0(f"ngram_two_pass: building cache orders={args.ngram_min_order}-{args.ngram_max_order} " + f"buckets={args.ngram_num_buckets}") + tokens_np = val_tokens.numpy().astype(np.int16) + cache = NgramCache( + min_order=args.ngram_min_order, + max_order=args.ngram_max_order, + num_buckets=args.ngram_num_buckets, + ) + cache.build_full(tokens_np) + t_cache = time.perf_counter() + log0(f"ngram_two_pass: cache built in {t_cache - t_pass1:.1f}s") + + # --- Pass 2: N-gram rescore --- + log0(f"ngram_two_pass: starting Pass 2 (n-gram rescore)") + val_loss, val_bpb = ngram_rescore( + args, tokens_np, cache, model_p, entropy, token_bytes, positions, + rank, world_size, device, log0=log0, + ) + t_pass2 = time.perf_counter() + log0(f"ngram_two_pass: Pass 2 done val_bpb={val_bpb:.6f} " + f"improvement={pass1_bpb - val_bpb:.6f} time={t_pass2 - t_cache:.1f}s") + log0(f"ngram_two_pass: total time={t_pass2 - t0:.1f}s") + + return val_loss, val_bpb + + +# === COMPLEMENTARY TRAINING === + +class TrainBigramTracker: + """Tracks bigram statistics from training data for complementary loss weighting.""" + + def __init__(self, vocab_size: int, device: torch.device): + # bigram_counts[prev_token, target_token] = count + self.counts = torch.zeros(vocab_size, vocab_size, device=device, 
dtype=torch.float32) + self.row_totals = torch.zeros(vocab_size, device=device, dtype=torch.float32) + + @torch.no_grad() + def update(self, x: Tensor, y: Tensor): + """Update bigram counts. x: context tokens, y: target tokens.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + idx = prev.long() * self.counts.shape[1] + tgt.long() + self.counts.view(-1).scatter_add_(0, idx, torch.ones_like(idx, dtype=torch.float32)) + self.row_totals.scatter_add_(0, prev.long(), torch.ones(prev.shape[0], device=prev.device, dtype=torch.float32)) + + @torch.no_grad() + def get_weights(self, x: Tensor, y: Tensor, alpha: float = 0.5) -> Tensor: + """Compute per-token loss weights: downweight tokens predictable by bigrams.""" + prev = x.reshape(-1) + tgt = y.reshape(-1) + totals = self.row_totals[prev.long()] + counts = self.counts[prev.long(), tgt.long()] + ngram_prob = counts / totals.clamp(min=1.0) + weights = (1.0 - alpha * ngram_prob).clamp(min=0.1) + return weights.reshape(y.shape) + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks 
from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = 
quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + 
torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len, args.val_tokens_limit) + (base_bytes_lut, has_leading_space_lut, is_boundary_token_lut), tokenizer_meta = load_tokenizer_luts( + args.tokenizer_path, + args.tokenizer_meta_path, + args.vocab_size, + device, + validate_meta=args.tokenizer_meta_validate, + ) + meta_path = tokenizer_meta.get("meta_path") + log0( + f"val_bpb:enabled tokenizer_kind={tokenizer_meta['tokenizer_kind']} " + f"tokenizer_path={args.tokenizer_path} tokenizer_meta_path={meta_path or 'fallback:none'}" + ) + log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + activation_mode=args.activation_mode, + activation_neg_slope=args.activation_neg_slope, + asymmetric_square_init=args.asymmetric_square_init, + gated_square_beta_init=args.gated_square_beta_init, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. 
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + # Separate compile for forward_logits (used in complementary training) + compiled_forward_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = 
Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} 
" + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"activation_mode:{args.activation_mode} neg_slope:{args.activation_neg_slope} " + f"asym_init:{args.asymmetric_square_init} gated_beta_init:{args.gated_square_beta_init}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if 
distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + # Complementary training tracker + bigram_tracker = TrainBigramTracker(args.vocab_size, device) if args.complement_enabled else None + if bigram_tracker is not None: + log0(f"complement:enabled alpha={args.complement_alpha}") + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + 
if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if args.complement_enabled and bigram_tracker is not None: + # Complementary training: single forward, weighted CE + logits = compiled_forward_logits(x) + logits_flat = logits.reshape(-1, logits.size(-1)).float() + per_token_nll = F.cross_entropy(logits_flat, y.reshape(-1), reduction="none") + comp_weights = bigram_tracker.get_weights(x, y, alpha=args.complement_alpha).reshape(-1) + loss = (per_token_nll * comp_weights).sum() / comp_weights.sum() + bigram_tracker.update(x, y) + else: + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + 
optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory 
allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: 
v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + activation_mode=args.activation_mode, + activation_neg_slope=args.activation_neg_slope, + asymmetric_square_init=args.asymmetric_square_init, + 
gated_square_beta_init=args.gated_square_beta_init, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + 
torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + # --- N-gram two-pass rescore --- + if args.ngram_enabled: + # Use TTT-adapted model if available, otherwise use quantized eval model + ngram_model = eval_model + torch.cuda.synchronize() + t_ngram = time.perf_counter() + ng_val_loss, ng_val_bpb = eval_ngram_two_pass( + args, ngram_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"ngram_two_pass val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms") + log0(f"ngram_two_pass_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed1337.log b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed1337.log new file mode 100644 index 0000000000..0872eda796 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed1337.log @@ -0,0 +1,398 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +***************************************** +logs/tm0054_legal_ttt_full_seed1337_r2.txt +val_bpb:enabled tokenizer_kind=tokenmonster tokenizer_path=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/tokenizers/candidate.vocab tokenizer_meta_path=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/tokenizers/candidate.meta.npz +train_loader:dataset:fineweb10B_tm0054 train_shards:79 +val_loader:shards pattern=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/datasets/fineweb10B_tm0054/fineweb_val_*.bin tokens:61020160 +model_params:26911580 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +activation_mode:leaky_relu_sq neg_slope:0.5 asym_init:0.25 gated_beta_init:1.0 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 
+warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/9000 train_loss:6.8943 train_time:133ms step_avg:133.45ms +step:2/9000 train_loss:9.0962 train_time:161ms step_avg:80.40ms +step:3/9000 train_loss:8.8962 train_time:243ms step_avg:80.96ms +step:4/9000 train_loss:8.4518 train_time:325ms step_avg:81.34ms +step:5/9000 train_loss:7.8813 train_time:407ms step_avg:81.47ms +step:6/9000 train_loss:7.3957 train_time:490ms step_avg:81.65ms +step:7/9000 train_loss:7.1001 train_time:575ms step_avg:82.16ms +step:8/9000 train_loss:6.8254 train_time:659ms step_avg:82.39ms +step:9/9000 train_loss:6.6073 train_time:741ms step_avg:82.36ms +step:10/9000 train_loss:6.3737 train_time:825ms step_avg:82.52ms +step:50/9000 train_loss:3.9960 train_time:4167ms step_avg:83.33ms +step:100/9000 train_loss:3.2253 train_time:8338ms step_avg:83.38ms +step:150/9000 train_loss:2.9620 train_time:12574ms step_avg:83.83ms +step:200/9000 train_loss:2.8154 train_time:16758ms step_avg:83.79ms +step:250/9000 train_loss:2.5509 train_time:20945ms step_avg:83.78ms +step:300/9000 train_loss:2.5182 train_time:25190ms step_avg:83.97ms +step:350/9000 train_loss:2.3841 train_time:29382ms step_avg:83.95ms +step:400/9000 train_loss:2.5005 train_time:33631ms step_avg:84.08ms +step:450/9000 train_loss:2.4195 train_time:37827ms step_avg:84.06ms +step:500/9000 train_loss:2.3448 train_time:42029ms step_avg:84.06ms +step:550/9000 train_loss:2.4393 train_time:46301ms step_avg:84.18ms +step:600/9000 train_loss:2.3860 train_time:50503ms step_avg:84.17ms +step:650/9000 train_loss:2.4902 train_time:54772ms step_avg:84.26ms +step:700/9000 train_loss:2.3098 train_time:58984ms step_avg:84.26ms +step:750/9000 train_loss:2.3169 train_time:63200ms step_avg:84.27ms +step:800/9000 train_loss:2.4018 
train_time:67480ms step_avg:84.35ms +step:850/9000 train_loss:2.3439 train_time:71704ms step_avg:84.36ms +step:900/9000 train_loss:2.1836 train_time:75968ms step_avg:84.41ms +step:950/9000 train_loss:2.3288 train_time:80199ms step_avg:84.42ms +step:1000/9000 train_loss:2.2208 train_time:84430ms step_avg:84.43ms +step:1050/9000 train_loss:2.1488 train_time:88692ms step_avg:84.47ms +step:1100/9000 train_loss:2.2503 train_time:92909ms step_avg:84.46ms +step:1150/9000 train_loss:2.3108 train_time:97187ms step_avg:84.51ms +step:1200/9000 train_loss:2.2302 train_time:101417ms step_avg:84.51ms +step:1250/9000 train_loss:2.2135 train_time:105638ms step_avg:84.51ms +step:1300/9000 train_loss:2.3129 train_time:109916ms step_avg:84.55ms +step:1350/9000 train_loss:2.2652 train_time:114130ms step_avg:84.54ms +step:1400/9000 train_loss:2.1949 train_time:118404ms step_avg:84.57ms +step:1450/9000 train_loss:2.1966 train_time:122616ms step_avg:84.56ms +step:1500/9000 train_loss:2.1837 train_time:126825ms step_avg:84.55ms +step:1550/9000 train_loss:2.2415 train_time:131098ms step_avg:84.58ms +step:1600/9000 train_loss:2.3242 train_time:135314ms step_avg:84.57ms +step:1650/9000 train_loss:2.2872 train_time:139529ms step_avg:84.56ms +step:1700/9000 train_loss:2.1840 train_time:143797ms step_avg:84.59ms +step:1750/9000 train_loss:2.2696 train_time:148015ms step_avg:84.58ms +step:1800/9000 train_loss:2.1818 train_time:152282ms step_avg:84.60ms +step:1850/9000 train_loss:2.3925 train_time:156499ms step_avg:84.59ms +step:1900/9000 train_loss:2.0809 train_time:160718ms step_avg:84.59ms +step:1950/9000 train_loss:2.1540 train_time:164984ms step_avg:84.61ms +step:2000/9000 train_loss:2.2138 train_time:169197ms step_avg:84.60ms +step:2050/9000 train_loss:2.0656 train_time:173462ms step_avg:84.62ms +step:2100/9000 train_loss:2.1823 train_time:177673ms step_avg:84.61ms +step:2150/9000 train_loss:2.0539 train_time:181881ms step_avg:84.60ms +step:2200/9000 train_loss:2.0688 train_time:186144ms 
step_avg:84.61ms +step:2250/9000 train_loss:2.1753 train_time:190359ms step_avg:84.60ms +step:2300/9000 train_loss:2.2515 train_time:194612ms step_avg:84.61ms +step:2350/9000 train_loss:2.1371 train_time:198829ms step_avg:84.61ms +step:2400/9000 train_loss:2.1352 train_time:203043ms step_avg:84.60ms +step:2450/9000 train_loss:2.3163 train_time:207307ms step_avg:84.62ms +step:2500/9000 train_loss:2.0402 train_time:211524ms step_avg:84.61ms +step:2550/9000 train_loss:2.1413 train_time:215784ms step_avg:84.62ms +step:2600/9000 train_loss:2.0413 train_time:219994ms step_avg:84.61ms +step:2650/9000 train_loss:2.0865 train_time:224203ms step_avg:84.60ms +step:2700/9000 train_loss:2.1818 train_time:228464ms step_avg:84.62ms +step:2750/9000 train_loss:2.2111 train_time:232673ms step_avg:84.61ms +step:2800/9000 train_loss:2.1570 train_time:236932ms step_avg:84.62ms +step:2850/9000 train_loss:2.1061 train_time:241142ms step_avg:84.61ms +step:2900/9000 train_loss:2.0616 train_time:245355ms step_avg:84.61ms +step:2950/9000 train_loss:1.9861 train_time:249614ms step_avg:84.61ms +step:3000/9000 train_loss:2.1446 train_time:253817ms step_avg:84.61ms +step:3050/9000 train_loss:2.0544 train_time:258024ms step_avg:84.60ms +step:3100/9000 train_loss:2.0910 train_time:262289ms step_avg:84.61ms +step:3150/9000 train_loss:2.2244 train_time:266500ms step_avg:84.60ms +step:3200/9000 train_loss:1.9779 train_time:270759ms step_avg:84.61ms +step:3250/9000 train_loss:2.1132 train_time:274978ms step_avg:84.61ms +step:3300/9000 train_loss:2.0910 train_time:279186ms step_avg:84.60ms +step:3350/9000 train_loss:2.1038 train_time:283448ms step_avg:84.61ms +step:3400/9000 train_loss:2.1098 train_time:287665ms step_avg:84.61ms +step:3450/9000 train_loss:2.1566 train_time:291922ms step_avg:84.62ms +step:3500/9000 train_loss:2.1111 train_time:296131ms step_avg:84.61ms +step:3550/9000 train_loss:2.0669 train_time:300343ms step_avg:84.60ms +step:3600/9000 train_loss:2.0914 train_time:304610ms 
step_avg:84.61ms +step:3650/9000 train_loss:2.0109 train_time:308814ms step_avg:84.61ms +step:3700/9000 train_loss:2.0867 train_time:313076ms step_avg:84.62ms +step:3750/9000 train_loss:2.2641 train_time:317276ms step_avg:84.61ms +step:3800/9000 train_loss:2.2269 train_time:321482ms step_avg:84.60ms +step:3850/9000 train_loss:2.0291 train_time:325742ms step_avg:84.61ms +step:3900/9000 train_loss:2.0243 train_time:329949ms step_avg:84.60ms +step:3950/9000 train_loss:2.2021 train_time:334203ms step_avg:84.61ms +step:4000/9000 train_loss:2.1699 train_time:338412ms step_avg:84.60ms +step:4050/9000 train_loss:2.0771 train_time:342621ms step_avg:84.60ms +step:4100/9000 train_loss:2.0659 train_time:346880ms step_avg:84.60ms +step:4150/9000 train_loss:2.0451 train_time:351086ms step_avg:84.60ms +step:4200/9000 train_loss:2.0842 train_time:355343ms step_avg:84.61ms +step:4250/9000 train_loss:2.0472 train_time:359553ms step_avg:84.60ms +step:4300/9000 train_loss:2.2025 train_time:363756ms step_avg:84.59ms +step:4350/9000 train_loss:2.0231 train_time:368021ms step_avg:84.60ms +step:4400/9000 train_loss:2.1058 train_time:372223ms step_avg:84.60ms +step:4450/9000 train_loss:2.1025 train_time:376426ms step_avg:84.59ms +step:4500/9000 train_loss:1.9587 train_time:380693ms step_avg:84.60ms +step:4550/9000 train_loss:2.0305 train_time:384893ms step_avg:84.59ms +step:4600/9000 train_loss:2.0529 train_time:389151ms step_avg:84.60ms +step:4650/9000 train_loss:2.2317 train_time:393360ms step_avg:84.59ms +step:4700/9000 train_loss:2.0178 train_time:397564ms step_avg:84.59ms +step:4750/9000 train_loss:2.0354 train_time:401825ms step_avg:84.59ms +step:4800/9000 train_loss:2.0001 train_time:406028ms step_avg:84.59ms +step:4850/9000 train_loss:2.1272 train_time:410286ms step_avg:84.60ms +step:4900/9000 train_loss:2.0908 train_time:414498ms step_avg:84.59ms +step:4950/9000 train_loss:1.9647 train_time:418698ms step_avg:84.59ms +step:5000/9000 train_loss:2.0511 train_time:422958ms 
step_avg:84.59ms +step:5050/9000 train_loss:2.0455 train_time:427161ms step_avg:84.59ms +step:5100/9000 train_loss:2.0486 train_time:431422ms step_avg:84.59ms +step:5150/9000 train_loss:2.2386 train_time:435627ms step_avg:84.59ms +step:5200/9000 train_loss:2.0852 train_time:439826ms step_avg:84.58ms +step:5250/9000 train_loss:1.9464 train_time:444086ms step_avg:84.59ms +step:5300/9000 train_loss:2.1142 train_time:448294ms step_avg:84.58ms +step:5350/9000 train_loss:2.2033 train_time:452542ms step_avg:84.59ms +step:5400/9000 train_loss:1.9199 train_time:456804ms step_avg:84.59ms +step:5450/9000 train_loss:2.0144 train_time:461038ms step_avg:84.59ms +step:5500/9000 train_loss:1.9251 train_time:465298ms step_avg:84.60ms +step:5550/9000 train_loss:2.1399 train_time:469497ms step_avg:84.59ms +step:5600/9000 train_loss:1.7047 train_time:473754ms step_avg:84.60ms +step:5650/9000 train_loss:2.0236 train_time:477958ms step_avg:84.59ms +step:5700/9000 train_loss:2.1448 train_time:482160ms step_avg:84.59ms +step:5750/9000 train_loss:2.0373 train_time:486417ms step_avg:84.59ms +step:5800/9000 train_loss:1.9707 train_time:490613ms step_avg:84.59ms +step:5850/9000 train_loss:2.0192 train_time:494877ms step_avg:84.59ms +step:5900/9000 train_loss:1.8729 train_time:499072ms step_avg:84.59ms +step:5950/9000 train_loss:2.0318 train_time:503273ms step_avg:84.58ms +step:6000/9000 train_loss:1.7809 train_time:507529ms step_avg:84.59ms +step:6050/9000 train_loss:1.9303 train_time:511730ms step_avg:84.58ms +step:6100/9000 train_loss:1.8574 train_time:515936ms step_avg:84.58ms +step:6150/9000 train_loss:2.1455 train_time:520194ms step_avg:84.58ms +step:6200/9000 train_loss:1.9855 train_time:524397ms step_avg:84.58ms +step:6250/9000 train_loss:2.1366 train_time:528651ms step_avg:84.58ms +step:6300/9000 train_loss:1.9550 train_time:532859ms step_avg:84.58ms +step:6350/9000 train_loss:1.9460 train_time:537058ms step_avg:84.58ms +swa:start step:6400 +step:6400/9000 train_loss:2.0160 
train_time:541316ms step_avg:84.58ms +step:6450/9000 train_loss:2.0628 train_time:545626ms step_avg:84.59ms +step:6500/9000 train_loss:1.9846 train_time:549955ms step_avg:84.61ms +step:6550/9000 train_loss:2.0731 train_time:554225ms step_avg:84.61ms +late_qat:enabled step:6565 scale:0.1499 +step:6600/9000 train_loss:1.9692 train_time:558487ms step_avg:84.62ms +step:6650/9000 train_loss:2.1867 train_time:562809ms step_avg:84.63ms +step:6700/9000 train_loss:2.1137 train_time:567088ms step_avg:84.64ms +step:6750/9000 train_loss:1.9456 train_time:571399ms step_avg:84.65ms +step:6800/9000 train_loss:1.9725 train_time:575665ms step_avg:84.66ms +step:6850/9000 train_loss:2.0897 train_time:579927ms step_avg:84.66ms +step:6900/9000 train_loss:2.0470 train_time:584251ms step_avg:84.67ms +step:6950/9000 train_loss:1.9573 train_time:588512ms step_avg:84.68ms +step:7000/9000 train_loss:2.0273 train_time:592825ms step_avg:84.69ms +step:7050/9000 train_loss:2.0572 train_time:597089ms step_avg:84.69ms +step:7084/9000 val_loss:1.9619 val_bpb:1.0979 train_time:600057ms step_avg:84.71ms +stopping_early: wallclock_cap train_time:600057ms step:7084/9000 +peak memory allocated: 21665 MiB reserved: 22214 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9600 val_bpb:1.0969 eval_time:1973ms +Serialized model: 105994166 bytes +Code size: 120316 bytes +Serialized model int6+lzma: 15730440 bytes +Total submission size int6+lzma: 15850756 bytes +final_int6_roundtrip val_loss:1.9757 val_bpb:1.1057 eval_time:19726ms +final_int6_roundtrip_exact val_loss:1.97574271 val_bpb:1.10565088 +final_int6_sliding_window val_loss:1.9370 val_bpb:1.0840 stride:64 eval_time:97276ms +final_int6_sliding_window_exact val_loss:1.93702615 val_bpb:1.08398224 +final_int8_zlib_roundtrip_exact val_loss:1.93702615 val_bpb:1.08398224 +ttt_sliding:start chunks=1863 chunk_tokens=32768 total_windows=953440 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 +ttt_sliding:params unfrozen=26907468 frozen=4112 + 
ttt_chunk [1/1863] bpb=1.113653 time=0.5s + ttt_chunk [11/1863] bpb=1.061820 time=2.7s + ttt_chunk [21/1863] bpb=1.059062 time=5.0s + ttt_chunk [31/1863] bpb=1.080179 time=7.2s + ttt_chunk [41/1863] bpb=1.085783 time=9.4s + ttt_chunk [51/1863] bpb=1.083940 time=11.6s + ttt_chunk [61/1863] bpb=1.083649 time=13.9s + ttt_chunk [71/1863] bpb=1.079889 time=16.1s + ttt_chunk [81/1863] bpb=1.080835 time=18.3s + ttt_chunk [91/1863] bpb=1.083227 time=20.6s + ttt_chunk [101/1863] bpb=1.088748 time=22.8s + ttt_chunk [111/1863] bpb=1.091371 time=25.0s + ttt_chunk [121/1863] bpb=1.085018 time=27.3s + ttt_chunk [131/1863] bpb=1.082369 time=29.5s + ttt_chunk [141/1863] bpb=1.082720 time=31.7s + ttt_chunk [151/1863] bpb=1.079922 time=34.0s + ttt_chunk [161/1863] bpb=1.078205 time=36.2s + ttt_chunk [171/1863] bpb=1.084240 time=38.4s + ttt_chunk [181/1863] bpb=1.084184 time=40.6s + ttt_chunk [191/1863] bpb=1.085309 time=42.9s + ttt_chunk [201/1863] bpb=1.086562 time=45.1s + ttt_chunk [211/1863] bpb=1.087345 time=47.3s + ttt_chunk [221/1863] bpb=1.086456 time=49.5s + ttt_chunk [231/1863] bpb=1.085720 time=51.8s + ttt_chunk [241/1863] bpb=1.084495 time=54.0s + ttt_chunk [251/1863] bpb=1.083673 time=56.2s + ttt_chunk [261/1863] bpb=1.082800 time=58.4s + ttt_chunk [271/1863] bpb=1.085760 time=60.7s + ttt_chunk [281/1863] bpb=1.086445 time=62.9s + ttt_chunk [291/1863] bpb=1.088386 time=65.1s + ttt_chunk [301/1863] bpb=1.087342 time=67.4s + ttt_chunk [311/1863] bpb=1.088319 time=69.6s + ttt_chunk [321/1863] bpb=1.088139 time=71.8s + ttt_chunk [331/1863] bpb=1.086710 time=74.0s + ttt_chunk [341/1863] bpb=1.086047 time=76.3s + ttt_chunk [351/1863] bpb=1.087252 time=78.5s + ttt_chunk [361/1863] bpb=1.086204 time=80.7s + ttt_chunk [371/1863] bpb=1.086440 time=83.0s + ttt_chunk [381/1863] bpb=1.087953 time=85.2s + ttt_chunk [391/1863] bpb=1.088653 time=87.4s + ttt_chunk [401/1863] bpb=1.086661 time=89.6s + ttt_chunk [411/1863] bpb=1.085837 time=91.9s + ttt_chunk [421/1863] bpb=1.086779 
time=94.1s + ttt_chunk [431/1863] bpb=1.086510 time=96.3s + ttt_chunk [441/1863] bpb=1.086943 time=98.6s + ttt_chunk [451/1863] bpb=1.086233 time=100.8s + ttt_chunk [461/1863] bpb=1.085161 time=103.0s + ttt_chunk [471/1863] bpb=1.085676 time=105.2s + ttt_chunk [481/1863] bpb=1.086439 time=107.5s + ttt_chunk [491/1863] bpb=1.087867 time=109.7s + ttt_chunk [501/1863] bpb=1.086972 time=111.9s + ttt_chunk [511/1863] bpb=1.087574 time=114.2s + ttt_chunk [521/1863] bpb=1.088629 time=116.4s + ttt_chunk [531/1863] bpb=1.089225 time=118.6s + ttt_chunk [541/1863] bpb=1.088308 time=120.9s + ttt_chunk [551/1863] bpb=1.088579 time=123.1s + ttt_chunk [561/1863] bpb=1.088360 time=125.3s + ttt_chunk [571/1863] bpb=1.088321 time=127.6s + ttt_chunk [581/1863] bpb=1.088099 time=129.8s + ttt_chunk [591/1863] bpb=1.087753 time=132.0s + ttt_chunk [601/1863] bpb=1.087938 time=134.2s + ttt_chunk [611/1863] bpb=1.087670 time=136.5s + ttt_chunk [621/1863] bpb=1.087584 time=138.7s + ttt_chunk [631/1863] bpb=1.087029 time=140.9s + ttt_chunk [641/1863] bpb=1.085421 time=143.1s + ttt_chunk [651/1863] bpb=1.085544 time=145.4s + ttt_chunk [661/1863] bpb=1.084671 time=147.6s + ttt_chunk [671/1863] bpb=1.084126 time=149.8s + ttt_chunk [681/1863] bpb=1.084196 time=152.0s + ttt_chunk [691/1863] bpb=1.083770 time=154.3s + ttt_chunk [701/1863] bpb=1.083566 time=156.5s + ttt_chunk [711/1863] bpb=1.083085 time=158.7s + ttt_chunk [721/1863] bpb=1.083181 time=160.9s + ttt_chunk [731/1863] bpb=1.082364 time=163.2s + ttt_chunk [741/1863] bpb=1.082024 time=165.4s + ttt_chunk [751/1863] bpb=1.082388 time=167.6s + ttt_chunk [761/1863] bpb=1.081929 time=169.8s + ttt_chunk [771/1863] bpb=1.082380 time=172.1s + ttt_chunk [781/1863] bpb=1.081706 time=174.3s + ttt_chunk [791/1863] bpb=1.082238 time=176.5s + ttt_chunk [801/1863] bpb=1.081751 time=178.7s + ttt_chunk [811/1863] bpb=1.081824 time=180.9s + ttt_chunk [821/1863] bpb=1.081270 time=183.2s + ttt_chunk [831/1863] bpb=1.081398 time=185.4s + ttt_chunk [841/1863] 
bpb=1.081179 time=187.6s + ttt_chunk [851/1863] bpb=1.080749 time=189.9s + ttt_chunk [861/1863] bpb=1.080712 time=192.1s + ttt_chunk [871/1863] bpb=1.080529 time=194.4s + ttt_chunk [881/1863] bpb=1.079998 time=196.6s + ttt_chunk [891/1863] bpb=1.080004 time=198.8s + ttt_chunk [901/1863] bpb=1.080480 time=201.1s + ttt_chunk [911/1863] bpb=1.080299 time=203.3s + ttt_chunk [921/1863] bpb=1.080525 time=205.5s + ttt_chunk [931/1863] bpb=1.080927 time=207.7s + ttt_chunk [941/1863] bpb=1.081450 time=210.0s + ttt_chunk [951/1863] bpb=1.081433 time=212.2s + ttt_chunk [961/1863] bpb=1.081568 time=214.5s + ttt_chunk [971/1863] bpb=1.081863 time=216.7s + ttt_chunk [981/1863] bpb=1.081376 time=218.9s + ttt_chunk [991/1863] bpb=1.081425 time=221.2s + ttt_chunk [1001/1863] bpb=1.082034 time=223.4s + ttt_chunk [1011/1863] bpb=1.082091 time=225.6s + ttt_chunk [1021/1863] bpb=1.082146 time=227.8s + ttt_chunk [1031/1863] bpb=1.082796 time=230.1s + ttt_chunk [1041/1863] bpb=1.082713 time=232.3s + ttt_chunk [1051/1863] bpb=1.083014 time=234.5s + ttt_chunk [1061/1863] bpb=1.082852 time=236.8s + ttt_chunk [1071/1863] bpb=1.083369 time=239.0s + ttt_chunk [1081/1863] bpb=1.083318 time=241.2s + ttt_chunk [1091/1863] bpb=1.082810 time=243.4s + ttt_chunk [1101/1863] bpb=1.083063 time=245.6s + ttt_chunk [1111/1863] bpb=1.082522 time=247.9s + ttt_chunk [1121/1863] bpb=1.082221 time=250.1s + ttt_chunk [1131/1863] bpb=1.082334 time=252.3s + ttt_chunk [1141/1863] bpb=1.082274 time=254.5s + ttt_chunk [1151/1863] bpb=1.081349 time=256.7s + ttt_chunk [1161/1863] bpb=1.081498 time=259.0s + ttt_chunk [1171/1863] bpb=1.081732 time=261.2s + ttt_chunk [1181/1863] bpb=1.082217 time=263.4s + ttt_chunk [1191/1863] bpb=1.081579 time=265.6s + ttt_chunk [1201/1863] bpb=1.081528 time=267.9s + ttt_chunk [1211/1863] bpb=1.081903 time=270.1s + ttt_chunk [1221/1863] bpb=1.081653 time=272.3s + ttt_chunk [1231/1863] bpb=1.080998 time=274.5s + ttt_chunk [1241/1863] bpb=1.080289 time=276.7s + ttt_chunk [1251/1863] 
bpb=1.079566 time=279.0s + ttt_chunk [1261/1863] bpb=1.079764 time=281.2s + ttt_chunk [1271/1863] bpb=1.079329 time=283.4s + ttt_chunk [1281/1863] bpb=1.079638 time=285.6s + ttt_chunk [1291/1863] bpb=1.079496 time=287.9s + ttt_chunk [1301/1863] bpb=1.079170 time=290.1s + ttt_chunk [1311/1863] bpb=1.079435 time=292.3s + ttt_chunk [1321/1863] bpb=1.079265 time=294.6s + ttt_chunk [1331/1863] bpb=1.079239 time=296.8s + ttt_chunk [1341/1863] bpb=1.079559 time=299.0s + ttt_chunk [1351/1863] bpb=1.080002 time=301.3s + ttt_chunk [1361/1863] bpb=1.080269 time=303.5s + ttt_chunk [1371/1863] bpb=1.079992 time=305.7s + ttt_chunk [1381/1863] bpb=1.079889 time=307.9s + ttt_chunk [1391/1863] bpb=1.080086 time=310.1s + ttt_chunk [1401/1863] bpb=1.080237 time=312.4s + ttt_chunk [1411/1863] bpb=1.080253 time=314.6s + ttt_chunk [1421/1863] bpb=1.080460 time=316.8s + ttt_chunk [1431/1863] bpb=1.080877 time=319.0s + ttt_chunk [1441/1863] bpb=1.080596 time=321.3s + ttt_chunk [1451/1863] bpb=1.080890 time=323.5s + ttt_chunk [1461/1863] bpb=1.080947 time=325.7s + ttt_chunk [1471/1863] bpb=1.081098 time=327.9s + ttt_chunk [1481/1863] bpb=1.080778 time=330.1s + ttt_chunk [1491/1863] bpb=1.080875 time=332.4s + ttt_chunk [1501/1863] bpb=1.081107 time=334.6s + ttt_chunk [1511/1863] bpb=1.080685 time=336.8s + ttt_chunk [1521/1863] bpb=1.080832 time=339.0s + ttt_chunk [1531/1863] bpb=1.081158 time=341.3s + ttt_chunk [1541/1863] bpb=1.081171 time=343.5s + ttt_chunk [1551/1863] bpb=1.081251 time=345.7s + ttt_chunk [1561/1863] bpb=1.081281 time=348.0s + ttt_chunk [1571/1863] bpb=1.081607 time=350.2s + ttt_chunk [1581/1863] bpb=1.081628 time=352.4s + ttt_chunk [1591/1863] bpb=1.082022 time=354.6s + ttt_chunk [1601/1863] bpb=1.082007 time=356.8s + ttt_chunk [1611/1863] bpb=1.081724 time=359.0s + ttt_chunk [1621/1863] bpb=1.081761 time=361.3s + ttt_chunk [1631/1863] bpb=1.082257 time=363.5s + ttt_chunk [1641/1863] bpb=1.082452 time=365.7s + ttt_chunk [1651/1863] bpb=1.082506 time=367.9s + ttt_chunk 
[1661/1863] bpb=1.082321 time=370.1s + ttt_chunk [1671/1863] bpb=1.082451 time=372.3s + ttt_chunk [1681/1863] bpb=1.082402 time=374.6s + ttt_chunk [1691/1863] bpb=1.082272 time=376.8s + ttt_chunk [1701/1863] bpb=1.082543 time=379.0s + ttt_chunk [1711/1863] bpb=1.081991 time=381.2s + ttt_chunk [1721/1863] bpb=1.081734 time=383.5s + ttt_chunk [1731/1863] bpb=1.081836 time=385.7s + ttt_chunk [1741/1863] bpb=1.081789 time=387.9s + ttt_chunk [1751/1863] bpb=1.081795 time=390.1s + ttt_chunk [1761/1863] bpb=1.081347 time=392.4s + ttt_chunk [1771/1863] bpb=1.081237 time=394.6s + ttt_chunk [1781/1863] bpb=1.081052 time=396.8s + ttt_chunk [1791/1863] bpb=1.081032 time=399.0s + ttt_chunk [1801/1863] bpb=1.080891 time=401.2s + ttt_chunk [1811/1863] bpb=1.080738 time=403.5s + ttt_chunk [1821/1863] bpb=1.080499 time=405.7s + ttt_chunk [1831/1863] bpb=1.080082 time=407.9s + ttt_chunk [1841/1863] bpb=1.079901 time=410.1s + ttt_chunk [1851/1863] bpb=1.079675 time=412.3s + ttt_chunk [1861/1863] bpb=1.079701 time=414.5s + ttt_chunk [1863/1863] bpb=1.079656 time=414.8s +ttt_sliding:done val_loss=1.931746 val_bpb=1.081027 elapsed=414.8s +legal_ttt val_loss:1.9317 val_bpb:1.0810 eval_time:415341ms +legal_ttt_exact val_loss:1.93174593 val_bpb:1.08102737 diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed2026.log b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed2026.log new file mode 100644 index 0000000000..0ccec4c2f6 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed2026.log @@ -0,0 +1,398 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+***************************************** +logs/tm0054_legal_ttt_full_seed2026.txt +val_bpb:enabled tokenizer_kind=tokenmonster tokenizer_path=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/tokenizers/candidate.vocab tokenizer_meta_path=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/tokenizers/candidate.meta.npz +train_loader:dataset:fineweb10B_tm0054 train_shards:79 +val_loader:shards pattern=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/datasets/fineweb10B_tm0054/fineweb_val_*.bin tokens:61020160 +model_params:26911580 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +activation_mode:leaky_relu_sq neg_slope:0.5 asym_init:0.25 gated_beta_init:1.0 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2026 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/9000 train_loss:6.8943 train_time:178ms step_avg:177.97ms +step:2/9000 train_loss:9.1164 train_time:207ms step_avg:103.68ms +step:3/9000 train_loss:8.7900 train_time:288ms step_avg:96.07ms +step:4/9000 train_loss:8.3765 train_time:370ms step_avg:92.58ms +step:5/9000 train_loss:7.9653 train_time:454ms step_avg:90.71ms +step:6/9000 train_loss:7.5898 train_time:537ms step_avg:89.42ms +step:7/9000 train_loss:7.2258 train_time:622ms step_avg:88.79ms +step:8/9000 train_loss:6.8996 train_time:705ms 
step_avg:88.16ms +step:9/9000 train_loss:6.6157 train_time:787ms step_avg:87.49ms +step:10/9000 train_loss:6.4672 train_time:872ms step_avg:87.21ms +step:50/9000 train_loss:3.9924 train_time:4209ms step_avg:84.18ms +step:100/9000 train_loss:3.2550 train_time:8381ms step_avg:83.81ms +step:150/9000 train_loss:2.9666 train_time:12626ms step_avg:84.17ms +step:200/9000 train_loss:2.8095 train_time:16812ms step_avg:84.06ms +step:250/9000 train_loss:2.5214 train_time:21004ms step_avg:84.02ms +step:300/9000 train_loss:2.5025 train_time:25247ms step_avg:84.16ms +step:350/9000 train_loss:2.3606 train_time:29439ms step_avg:84.11ms +step:400/9000 train_loss:2.4753 train_time:33687ms step_avg:84.22ms +step:450/9000 train_loss:2.4010 train_time:37882ms step_avg:84.18ms +step:500/9000 train_loss:2.3306 train_time:42083ms step_avg:84.17ms +step:550/9000 train_loss:2.4264 train_time:46348ms step_avg:84.27ms +step:600/9000 train_loss:2.3739 train_time:50550ms step_avg:84.25ms +step:650/9000 train_loss:2.4754 train_time:54817ms step_avg:84.33ms +step:700/9000 train_loss:2.2959 train_time:59029ms step_avg:84.33ms +step:750/9000 train_loss:2.3130 train_time:63242ms step_avg:84.32ms +step:800/9000 train_loss:2.3867 train_time:67518ms step_avg:84.40ms +step:850/9000 train_loss:2.3287 train_time:71728ms step_avg:84.39ms +step:900/9000 train_loss:2.1767 train_time:75994ms step_avg:84.44ms +step:950/9000 train_loss:2.3250 train_time:80216ms step_avg:84.44ms +step:1000/9000 train_loss:2.2106 train_time:84425ms step_avg:84.43ms +step:1050/9000 train_loss:2.1388 train_time:88689ms step_avg:84.47ms +step:1100/9000 train_loss:2.2481 train_time:92898ms step_avg:84.45ms +step:1150/9000 train_loss:2.3038 train_time:97173ms step_avg:84.50ms +step:1200/9000 train_loss:2.2263 train_time:101389ms step_avg:84.49ms +step:1250/9000 train_loss:2.2078 train_time:105597ms step_avg:84.48ms +step:1300/9000 train_loss:2.3052 train_time:109868ms step_avg:84.51ms +step:1350/9000 train_loss:2.2582 
train_time:114086ms step_avg:84.51ms +step:1400/9000 train_loss:2.1905 train_time:118352ms step_avg:84.54ms +step:1450/9000 train_loss:2.1921 train_time:122561ms step_avg:84.52ms +step:1500/9000 train_loss:2.1773 train_time:126775ms step_avg:84.52ms +step:1550/9000 train_loss:2.2366 train_time:131044ms step_avg:84.54ms +step:1600/9000 train_loss:2.3133 train_time:135253ms step_avg:84.53ms +step:1650/9000 train_loss:2.2826 train_time:139463ms step_avg:84.52ms +step:1700/9000 train_loss:2.1782 train_time:143736ms step_avg:84.55ms +step:1750/9000 train_loss:2.2636 train_time:147940ms step_avg:84.54ms +step:1800/9000 train_loss:2.1782 train_time:152209ms step_avg:84.56ms +step:1850/9000 train_loss:2.3930 train_time:156427ms step_avg:84.55ms +step:1900/9000 train_loss:2.0828 train_time:160633ms step_avg:84.54ms +step:1950/9000 train_loss:2.1483 train_time:164904ms step_avg:84.57ms +step:2000/9000 train_loss:2.2095 train_time:169118ms step_avg:84.56ms +step:2050/9000 train_loss:2.0608 train_time:173395ms step_avg:84.58ms +step:2100/9000 train_loss:2.1754 train_time:177601ms step_avg:84.57ms +step:2150/9000 train_loss:2.0523 train_time:181807ms step_avg:84.56ms +step:2200/9000 train_loss:2.0676 train_time:186076ms step_avg:84.58ms +step:2250/9000 train_loss:2.1734 train_time:190280ms step_avg:84.57ms +step:2300/9000 train_loss:2.2498 train_time:194540ms step_avg:84.58ms +step:2350/9000 train_loss:2.1355 train_time:198754ms step_avg:84.58ms +step:2400/9000 train_loss:2.1346 train_time:202954ms step_avg:84.56ms +step:2450/9000 train_loss:2.3184 train_time:207219ms step_avg:84.58ms +step:2500/9000 train_loss:2.0380 train_time:211431ms step_avg:84.57ms +step:2550/9000 train_loss:2.1408 train_time:215691ms step_avg:84.58ms +step:2600/9000 train_loss:2.0400 train_time:219896ms step_avg:84.58ms +step:2650/9000 train_loss:2.0861 train_time:224100ms step_avg:84.57ms +step:2700/9000 train_loss:2.1776 train_time:228365ms step_avg:84.58ms +step:2750/9000 train_loss:2.2101 
train_time:232572ms step_avg:84.57ms +step:2800/9000 train_loss:2.1534 train_time:236837ms step_avg:84.58ms +step:2850/9000 train_loss:2.1041 train_time:241039ms step_avg:84.57ms +step:2900/9000 train_loss:2.0606 train_time:245245ms step_avg:84.57ms +step:2950/9000 train_loss:1.9852 train_time:249513ms step_avg:84.58ms +step:3000/9000 train_loss:2.1472 train_time:253719ms step_avg:84.57ms +step:3050/9000 train_loss:2.0573 train_time:257926ms step_avg:84.57ms +step:3100/9000 train_loss:2.0906 train_time:262197ms step_avg:84.58ms +step:3150/9000 train_loss:2.2228 train_time:266477ms step_avg:84.60ms +step:3200/9000 train_loss:1.9756 train_time:270732ms step_avg:84.60ms +step:3250/9000 train_loss:2.1110 train_time:274949ms step_avg:84.60ms +step:3300/9000 train_loss:2.0923 train_time:279154ms step_avg:84.59ms +step:3350/9000 train_loss:2.1052 train_time:283413ms step_avg:84.60ms +step:3400/9000 train_loss:2.1078 train_time:287618ms step_avg:84.59ms +step:3450/9000 train_loss:2.1559 train_time:291878ms step_avg:84.60ms +step:3500/9000 train_loss:2.1088 train_time:296076ms step_avg:84.59ms +step:3550/9000 train_loss:2.0715 train_time:300279ms step_avg:84.59ms +step:3600/9000 train_loss:2.0896 train_time:304542ms step_avg:84.60ms +step:3650/9000 train_loss:2.0158 train_time:308744ms step_avg:84.59ms +step:3700/9000 train_loss:2.0864 train_time:313002ms step_avg:84.60ms +step:3750/9000 train_loss:2.2640 train_time:317204ms step_avg:84.59ms +step:3800/9000 train_loss:2.2287 train_time:321401ms step_avg:84.58ms +step:3850/9000 train_loss:2.0263 train_time:325661ms step_avg:84.59ms +step:3900/9000 train_loss:2.0263 train_time:329860ms step_avg:84.58ms +step:3950/9000 train_loss:2.2024 train_time:334121ms step_avg:84.59ms +step:4000/9000 train_loss:2.1712 train_time:338329ms step_avg:84.58ms +step:4050/9000 train_loss:2.0749 train_time:342528ms step_avg:84.57ms +step:4100/9000 train_loss:2.0649 train_time:346789ms step_avg:84.58ms +step:4150/9000 train_loss:2.0440 
train_time:350993ms step_avg:84.58ms +step:4200/9000 train_loss:2.0877 train_time:355251ms step_avg:84.58ms +step:4250/9000 train_loss:2.0477 train_time:359452ms step_avg:84.58ms +step:4300/9000 train_loss:2.2034 train_time:363653ms step_avg:84.57ms +step:4350/9000 train_loss:2.0262 train_time:367917ms step_avg:84.58ms +step:4400/9000 train_loss:2.1015 train_time:372119ms step_avg:84.57ms +step:4450/9000 train_loss:2.0993 train_time:376318ms step_avg:84.57ms +step:4500/9000 train_loss:1.9574 train_time:380579ms step_avg:84.57ms +step:4550/9000 train_loss:2.0309 train_time:384780ms step_avg:84.57ms +step:4600/9000 train_loss:2.0508 train_time:389043ms step_avg:84.57ms +step:4650/9000 train_loss:2.2327 train_time:393238ms step_avg:84.57ms +step:4700/9000 train_loss:2.0182 train_time:397439ms step_avg:84.56ms +step:4750/9000 train_loss:2.0352 train_time:401699ms step_avg:84.57ms +step:4800/9000 train_loss:2.0018 train_time:405898ms step_avg:84.56ms +step:4850/9000 train_loss:2.1292 train_time:410155ms step_avg:84.57ms +step:4900/9000 train_loss:2.0944 train_time:414359ms step_avg:84.56ms +step:4950/9000 train_loss:1.9678 train_time:418559ms step_avg:84.56ms +step:5000/9000 train_loss:2.0508 train_time:422818ms step_avg:84.56ms +step:5050/9000 train_loss:2.0417 train_time:427015ms step_avg:84.56ms +step:5100/9000 train_loss:2.0488 train_time:431269ms step_avg:84.56ms +step:5150/9000 train_loss:2.2378 train_time:435464ms step_avg:84.56ms +step:5200/9000 train_loss:2.0835 train_time:439660ms step_avg:84.55ms +step:5250/9000 train_loss:1.9464 train_time:443918ms step_avg:84.56ms +step:5300/9000 train_loss:2.1155 train_time:448114ms step_avg:84.55ms +step:5350/9000 train_loss:2.2009 train_time:452364ms step_avg:84.55ms +step:5400/9000 train_loss:1.9180 train_time:456572ms step_avg:84.55ms +step:5450/9000 train_loss:2.0152 train_time:460771ms step_avg:84.55ms +step:5500/9000 train_loss:1.9257 train_time:465030ms step_avg:84.55ms +step:5550/9000 train_loss:2.1385 
train_time:469227ms step_avg:84.55ms +step:5600/9000 train_loss:1.7015 train_time:473480ms step_avg:84.55ms +step:5650/9000 train_loss:2.0254 train_time:477678ms step_avg:84.54ms +step:5700/9000 train_loss:2.1469 train_time:481878ms step_avg:84.54ms +step:5750/9000 train_loss:2.0394 train_time:486135ms step_avg:84.55ms +step:5800/9000 train_loss:1.9690 train_time:490332ms step_avg:84.54ms +step:5850/9000 train_loss:2.0193 train_time:494594ms step_avg:84.55ms +step:5900/9000 train_loss:1.8722 train_time:498788ms step_avg:84.54ms +step:5950/9000 train_loss:2.0301 train_time:502986ms step_avg:84.54ms +step:6000/9000 train_loss:1.7794 train_time:507240ms step_avg:84.54ms +step:6050/9000 train_loss:1.9326 train_time:511438ms step_avg:84.54ms +step:6100/9000 train_loss:1.8587 train_time:515633ms step_avg:84.53ms +step:6150/9000 train_loss:2.1469 train_time:519886ms step_avg:84.53ms +step:6200/9000 train_loss:1.9863 train_time:524083ms step_avg:84.53ms +step:6250/9000 train_loss:2.1374 train_time:528335ms step_avg:84.53ms +step:6300/9000 train_loss:1.9545 train_time:532534ms step_avg:84.53ms +step:6350/9000 train_loss:1.9479 train_time:536726ms step_avg:84.52ms +swa:start step:6400 +step:6400/9000 train_loss:2.0168 train_time:540982ms step_avg:84.53ms +step:6450/9000 train_loss:2.0632 train_time:545281ms step_avg:84.54ms +step:6500/9000 train_loss:1.9883 train_time:549612ms step_avg:84.56ms +step:6550/9000 train_loss:2.0743 train_time:553878ms step_avg:84.56ms +late_qat:enabled step:6570 scale:0.1498 +step:6600/9000 train_loss:1.9664 train_time:558133ms step_avg:84.57ms +step:6650/9000 train_loss:2.1894 train_time:562450ms step_avg:84.58ms +step:6700/9000 train_loss:2.1126 train_time:566713ms step_avg:84.58ms +step:6750/9000 train_loss:1.9444 train_time:571030ms step_avg:84.60ms +step:6800/9000 train_loss:1.9691 train_time:575297ms step_avg:84.60ms +step:6850/9000 train_loss:2.0901 train_time:579560ms step_avg:84.61ms +step:6900/9000 train_loss:2.0447 train_time:583873ms 
step_avg:84.62ms +step:6950/9000 train_loss:1.9597 train_time:588131ms step_avg:84.62ms +step:7000/9000 train_loss:2.0296 train_time:592438ms step_avg:84.63ms +step:7050/9000 train_loss:2.0547 train_time:596690ms step_avg:84.64ms +step:7089/9000 val_loss:1.9614 val_bpb:1.0977 train_time:600082ms step_avg:84.65ms +stopping_early: wallclock_cap train_time:600082ms step:7089/9000 +peak memory allocated: 21655 MiB reserved: 22246 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9596 val_bpb:1.0966 eval_time:1972ms +Serialized model: 105994166 bytes +Code size: 120316 bytes +Serialized model int6+lzma: 15729476 bytes +Total submission size int6+lzma: 15849792 bytes +final_int6_roundtrip val_loss:1.9744 val_bpb:1.1049 eval_time:6409ms +final_int6_roundtrip_exact val_loss:1.97441759 val_bpb:1.10490932 +final_int6_sliding_window val_loss:1.9356 val_bpb:1.0832 stride:64 eval_time:74064ms +final_int6_sliding_window_exact val_loss:1.93555667 val_bpb:1.08315990 +final_int8_zlib_roundtrip_exact val_loss:1.93555667 val_bpb:1.08315990 +ttt_sliding:start chunks=1863 chunk_tokens=32768 total_windows=953440 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 +ttt_sliding:params unfrozen=26907468 frozen=4112 + ttt_chunk [1/1863] bpb=1.121831 time=0.5s + ttt_chunk [11/1863] bpb=1.062488 time=2.7s + ttt_chunk [21/1863] bpb=1.058702 time=5.0s + ttt_chunk [31/1863] bpb=1.079308 time=7.2s + ttt_chunk [41/1863] bpb=1.084762 time=9.4s + ttt_chunk [51/1863] bpb=1.083185 time=11.7s + ttt_chunk [61/1863] bpb=1.082808 time=13.9s + ttt_chunk [71/1863] bpb=1.078797 time=16.2s + ttt_chunk [81/1863] bpb=1.079805 time=18.4s + ttt_chunk [91/1863] bpb=1.082637 time=20.7s + ttt_chunk [101/1863] bpb=1.087932 time=22.9s + ttt_chunk [111/1863] bpb=1.090424 time=25.2s + ttt_chunk [121/1863] bpb=1.084171 time=27.4s + ttt_chunk [131/1863] bpb=1.081555 time=29.7s + ttt_chunk [141/1863] bpb=1.081859 time=31.9s + ttt_chunk [151/1863] bpb=1.078993 time=34.2s + ttt_chunk [161/1863] bpb=1.077349 
time=36.4s + ttt_chunk [171/1863] bpb=1.083328 time=38.7s + ttt_chunk [181/1863] bpb=1.083330 time=40.9s + ttt_chunk [191/1863] bpb=1.084367 time=43.2s + ttt_chunk [201/1863] bpb=1.085599 time=45.4s + ttt_chunk [211/1863] bpb=1.086286 time=47.7s + ttt_chunk [221/1863] bpb=1.085551 time=49.9s + ttt_chunk [231/1863] bpb=1.084799 time=52.1s + ttt_chunk [241/1863] bpb=1.083630 time=54.4s + ttt_chunk [251/1863] bpb=1.082893 time=56.7s + ttt_chunk [261/1863] bpb=1.082021 time=58.9s + ttt_chunk [271/1863] bpb=1.085016 time=61.1s + ttt_chunk [281/1863] bpb=1.085706 time=63.4s + ttt_chunk [291/1863] bpb=1.087687 time=65.6s + ttt_chunk [301/1863] bpb=1.086637 time=67.9s + ttt_chunk [311/1863] bpb=1.087624 time=70.1s + ttt_chunk [321/1863] bpb=1.087494 time=72.4s + ttt_chunk [331/1863] bpb=1.085973 time=74.6s + ttt_chunk [341/1863] bpb=1.085304 time=76.9s + ttt_chunk [351/1863] bpb=1.086501 time=79.1s + ttt_chunk [361/1863] bpb=1.085388 time=81.4s + ttt_chunk [371/1863] bpb=1.085717 time=83.6s + ttt_chunk [381/1863] bpb=1.087279 time=85.9s + ttt_chunk [391/1863] bpb=1.088029 time=88.1s + ttt_chunk [401/1863] bpb=1.086074 time=90.4s + ttt_chunk [411/1863] bpb=1.085230 time=92.6s + ttt_chunk [421/1863] bpb=1.086189 time=94.8s + ttt_chunk [431/1863] bpb=1.085913 time=97.1s + ttt_chunk [441/1863] bpb=1.086262 time=99.3s + ttt_chunk [451/1863] bpb=1.085588 time=101.6s + ttt_chunk [461/1863] bpb=1.084547 time=103.8s + ttt_chunk [471/1863] bpb=1.085018 time=106.1s + ttt_chunk [481/1863] bpb=1.085773 time=108.3s + ttt_chunk [491/1863] bpb=1.087146 time=110.5s + ttt_chunk [501/1863] bpb=1.086304 time=112.8s + ttt_chunk [511/1863] bpb=1.086894 time=115.0s + ttt_chunk [521/1863] bpb=1.087932 time=117.3s + ttt_chunk [531/1863] bpb=1.088556 time=119.5s + ttt_chunk [541/1863] bpb=1.087658 time=121.7s + ttt_chunk [551/1863] bpb=1.087950 time=124.0s + ttt_chunk [561/1863] bpb=1.087744 time=126.2s + ttt_chunk [571/1863] bpb=1.087681 time=128.5s + ttt_chunk [581/1863] bpb=1.087463 time=130.7s 
+ ttt_chunk [591/1863] bpb=1.087106 time=133.0s + ttt_chunk [601/1863] bpb=1.087287 time=135.2s + ttt_chunk [611/1863] bpb=1.087051 time=137.5s + ttt_chunk [621/1863] bpb=1.086983 time=139.7s + ttt_chunk [631/1863] bpb=1.086438 time=141.9s + ttt_chunk [641/1863] bpb=1.084858 time=144.2s + ttt_chunk [651/1863] bpb=1.085004 time=146.4s + ttt_chunk [661/1863] bpb=1.084167 time=148.7s + ttt_chunk [671/1863] bpb=1.083580 time=150.9s + ttt_chunk [681/1863] bpb=1.083631 time=153.2s + ttt_chunk [691/1863] bpb=1.083184 time=155.4s + ttt_chunk [701/1863] bpb=1.082974 time=157.6s + ttt_chunk [711/1863] bpb=1.082475 time=159.9s + ttt_chunk [721/1863] bpb=1.082591 time=162.1s + ttt_chunk [731/1863] bpb=1.081812 time=164.4s + ttt_chunk [741/1863] bpb=1.081423 time=166.6s + ttt_chunk [751/1863] bpb=1.081767 time=168.8s + ttt_chunk [761/1863] bpb=1.081343 time=171.1s + ttt_chunk [771/1863] bpb=1.081807 time=173.3s + ttt_chunk [781/1863] bpb=1.081125 time=175.5s + ttt_chunk [791/1863] bpb=1.081638 time=177.8s + ttt_chunk [801/1863] bpb=1.081130 time=180.0s + ttt_chunk [811/1863] bpb=1.081191 time=182.2s + ttt_chunk [821/1863] bpb=1.080664 time=184.5s + ttt_chunk [831/1863] bpb=1.080767 time=186.7s + ttt_chunk [841/1863] bpb=1.080525 time=188.9s + ttt_chunk [851/1863] bpb=1.080096 time=191.2s + ttt_chunk [861/1863] bpb=1.080054 time=193.4s + ttt_chunk [871/1863] bpb=1.079852 time=195.7s + ttt_chunk [881/1863] bpb=1.079311 time=197.9s + ttt_chunk [891/1863] bpb=1.079337 time=200.1s + ttt_chunk [901/1863] bpb=1.079812 time=202.4s + ttt_chunk [911/1863] bpb=1.079618 time=204.6s + ttt_chunk [921/1863] bpb=1.079872 time=206.8s + ttt_chunk [931/1863] bpb=1.080272 time=209.0s + ttt_chunk [941/1863] bpb=1.080774 time=211.3s + ttt_chunk [951/1863] bpb=1.080771 time=213.5s + ttt_chunk [961/1863] bpb=1.080925 time=215.7s + ttt_chunk [971/1863] bpb=1.081214 time=218.0s + ttt_chunk [981/1863] bpb=1.080735 time=220.2s + ttt_chunk [991/1863] bpb=1.080745 time=222.5s + ttt_chunk [1001/1863] 
bpb=1.081361 time=224.7s + ttt_chunk [1011/1863] bpb=1.081406 time=226.9s + ttt_chunk [1021/1863] bpb=1.081462 time=229.2s + ttt_chunk [1031/1863] bpb=1.082096 time=231.4s + ttt_chunk [1041/1863] bpb=1.082032 time=233.6s + ttt_chunk [1051/1863] bpb=1.082321 time=235.9s + ttt_chunk [1061/1863] bpb=1.082194 time=238.1s + ttt_chunk [1071/1863] bpb=1.082698 time=240.3s + ttt_chunk [1081/1863] bpb=1.082625 time=242.6s + ttt_chunk [1091/1863] bpb=1.082128 time=244.8s + ttt_chunk [1101/1863] bpb=1.082350 time=247.0s + ttt_chunk [1111/1863] bpb=1.081790 time=249.3s + ttt_chunk [1121/1863] bpb=1.081489 time=251.5s + ttt_chunk [1131/1863] bpb=1.081604 time=253.8s + ttt_chunk [1141/1863] bpb=1.081521 time=256.0s + ttt_chunk [1151/1863] bpb=1.080592 time=258.2s + ttt_chunk [1161/1863] bpb=1.080732 time=260.5s + ttt_chunk [1171/1863] bpb=1.080957 time=262.7s + ttt_chunk [1181/1863] bpb=1.081422 time=265.0s + ttt_chunk [1191/1863] bpb=1.080779 time=267.2s + ttt_chunk [1201/1863] bpb=1.080746 time=269.4s + ttt_chunk [1211/1863] bpb=1.081110 time=271.7s + ttt_chunk [1221/1863] bpb=1.080885 time=273.9s + ttt_chunk [1231/1863] bpb=1.080223 time=276.2s + ttt_chunk [1241/1863] bpb=1.079513 time=278.4s + ttt_chunk [1251/1863] bpb=1.078811 time=280.6s + ttt_chunk [1261/1863] bpb=1.079005 time=282.9s + ttt_chunk [1271/1863] bpb=1.078583 time=285.1s + ttt_chunk [1281/1863] bpb=1.078912 time=287.3s + ttt_chunk [1291/1863] bpb=1.078791 time=289.6s + ttt_chunk [1301/1863] bpb=1.078470 time=291.8s + ttt_chunk [1311/1863] bpb=1.078732 time=294.0s + ttt_chunk [1321/1863] bpb=1.078568 time=296.3s + ttt_chunk [1331/1863] bpb=1.078538 time=298.5s + ttt_chunk [1341/1863] bpb=1.078859 time=300.7s + ttt_chunk [1351/1863] bpb=1.079308 time=303.0s + ttt_chunk [1361/1863] bpb=1.079614 time=305.2s + ttt_chunk [1371/1863] bpb=1.079332 time=307.4s + ttt_chunk [1381/1863] bpb=1.079209 time=309.7s + ttt_chunk [1391/1863] bpb=1.079420 time=311.9s + ttt_chunk [1401/1863] bpb=1.079577 time=314.2s + ttt_chunk 
[1411/1863] bpb=1.079595 time=316.4s + ttt_chunk [1421/1863] bpb=1.079775 time=318.7s + ttt_chunk [1431/1863] bpb=1.080195 time=320.9s + ttt_chunk [1441/1863] bpb=1.079916 time=323.1s + ttt_chunk [1451/1863] bpb=1.080204 time=325.4s + ttt_chunk [1461/1863] bpb=1.080269 time=327.6s + ttt_chunk [1471/1863] bpb=1.080409 time=329.8s + ttt_chunk [1481/1863] bpb=1.080104 time=332.1s + ttt_chunk [1491/1863] bpb=1.080201 time=334.3s + ttt_chunk [1501/1863] bpb=1.080417 time=336.5s + ttt_chunk [1511/1863] bpb=1.079990 time=338.8s + ttt_chunk [1521/1863] bpb=1.080153 time=341.0s + ttt_chunk [1531/1863] bpb=1.080479 time=343.3s + ttt_chunk [1541/1863] bpb=1.080499 time=345.5s + ttt_chunk [1551/1863] bpb=1.080576 time=347.7s + ttt_chunk [1561/1863] bpb=1.080617 time=350.0s + ttt_chunk [1571/1863] bpb=1.080953 time=352.2s + ttt_chunk [1581/1863] bpb=1.080968 time=354.4s + ttt_chunk [1591/1863] bpb=1.081359 time=356.7s + ttt_chunk [1601/1863] bpb=1.081345 time=358.9s + ttt_chunk [1611/1863] bpb=1.081069 time=361.1s + ttt_chunk [1621/1863] bpb=1.081110 time=363.4s + ttt_chunk [1631/1863] bpb=1.081613 time=365.6s + ttt_chunk [1641/1863] bpb=1.081813 time=367.9s + ttt_chunk [1651/1863] bpb=1.081875 time=370.1s + ttt_chunk [1661/1863] bpb=1.081698 time=372.4s + ttt_chunk [1671/1863] bpb=1.081816 time=374.6s + ttt_chunk [1681/1863] bpb=1.081772 time=376.9s + ttt_chunk [1691/1863] bpb=1.081635 time=379.1s + ttt_chunk [1701/1863] bpb=1.081909 time=381.3s + ttt_chunk [1711/1863] bpb=1.081360 time=383.6s + ttt_chunk [1721/1863] bpb=1.081104 time=385.8s + ttt_chunk [1731/1863] bpb=1.081214 time=388.0s + ttt_chunk [1741/1863] bpb=1.081184 time=390.2s + ttt_chunk [1751/1863] bpb=1.081201 time=392.5s + ttt_chunk [1761/1863] bpb=1.080756 time=394.7s + ttt_chunk [1771/1863] bpb=1.080647 time=397.0s + ttt_chunk [1781/1863] bpb=1.080445 time=399.2s + ttt_chunk [1791/1863] bpb=1.080451 time=401.4s + ttt_chunk [1801/1863] bpb=1.080299 time=403.7s + ttt_chunk [1811/1863] bpb=1.080165 time=405.9s + 
ttt_chunk [1821/1863] bpb=1.079924 time=408.1s + ttt_chunk [1831/1863] bpb=1.079509 time=410.4s + ttt_chunk [1841/1863] bpb=1.079331 time=412.6s + ttt_chunk [1851/1863] bpb=1.079108 time=414.9s + ttt_chunk [1861/1863] bpb=1.079134 time=417.1s + ttt_chunk [1863/1863] bpb=1.079088 time=417.3s +ttt_sliding:done val_loss=1.930951 val_bpb=1.080583 elapsed=417.3s +legal_ttt val_loss:1.9310 val_bpb:1.0806 eval_time:417865ms +legal_ttt_exact val_loss:1.93095116 val_bpb:1.08058261 diff --git a/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed42.log b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed42.log new file mode 100644 index 0000000000..96f950ddef --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_tm0054_AutoresearchTokenizer_LegalTTT/train_seed42.log @@ -0,0 +1,398 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+***************************************** +logs/tm0054_legal_ttt_full_seed42.txt +val_bpb:enabled tokenizer_kind=tokenmonster tokenizer_path=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/tokenizers/candidate.vocab tokenizer_meta_path=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/tokenizers/candidate.meta.npz +train_loader:dataset:fineweb10B_tm0054 train_shards:79 +val_loader:shards pattern=/tmp/parameter-golf-tm0054-competition-full/tm0054_competition_export/datasets/fineweb10B_tm0054/fineweb_val_*.bin tokens:61020160 +model_params:26911580 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +activation_mode:leaky_relu_sq neg_slope:0.5 asym_init:0.25 gated_beta_init:1.0 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/9000 train_loss:6.8934 train_time:196ms step_avg:196.25ms +step:2/9000 train_loss:9.0805 train_time:226ms step_avg:112.82ms +step:3/9000 train_loss:8.7459 train_time:307ms step_avg:102.30ms +step:4/9000 train_loss:8.3084 train_time:389ms step_avg:97.29ms +step:5/9000 train_loss:7.8321 train_time:471ms step_avg:94.21ms +step:6/9000 train_loss:7.5364 train_time:555ms step_avg:92.48ms +step:7/9000 train_loss:7.2533 train_time:645ms step_avg:92.15ms +step:8/9000 train_loss:6.9110 train_time:726ms 
step_avg:90.74ms +step:9/9000 train_loss:6.7315 train_time:808ms step_avg:89.82ms +step:10/9000 train_loss:6.5260 train_time:891ms step_avg:89.15ms +step:50/9000 train_loss:3.9949 train_time:4235ms step_avg:84.70ms +step:100/9000 train_loss:3.2537 train_time:8416ms step_avg:84.16ms +step:150/9000 train_loss:2.9568 train_time:12658ms step_avg:84.39ms +step:200/9000 train_loss:2.8035 train_time:16852ms step_avg:84.26ms +step:250/9000 train_loss:2.5364 train_time:21038ms step_avg:84.15ms +step:300/9000 train_loss:2.5030 train_time:25283ms step_avg:84.28ms +step:350/9000 train_loss:2.3722 train_time:29477ms step_avg:84.22ms +step:400/9000 train_loss:2.4889 train_time:33732ms step_avg:84.33ms +step:450/9000 train_loss:2.4109 train_time:37935ms step_avg:84.30ms +step:500/9000 train_loss:2.3388 train_time:42141ms step_avg:84.28ms +step:550/9000 train_loss:2.4299 train_time:46409ms step_avg:84.38ms +step:600/9000 train_loss:2.3788 train_time:50615ms step_avg:84.36ms +step:650/9000 train_loss:2.4806 train_time:54879ms step_avg:84.43ms +step:700/9000 train_loss:2.3067 train_time:59091ms step_avg:84.42ms +step:750/9000 train_loss:2.3112 train_time:63297ms step_avg:84.40ms +step:800/9000 train_loss:2.3984 train_time:67573ms step_avg:84.47ms +step:850/9000 train_loss:2.3433 train_time:71783ms step_avg:84.45ms +step:900/9000 train_loss:2.1829 train_time:76040ms step_avg:84.49ms +step:950/9000 train_loss:2.3248 train_time:80261ms step_avg:84.49ms +step:1000/9000 train_loss:2.2206 train_time:84471ms step_avg:84.47ms +step:1050/9000 train_loss:2.1441 train_time:88737ms step_avg:84.51ms +step:1100/9000 train_loss:2.2517 train_time:92951ms step_avg:84.50ms +step:1150/9000 train_loss:2.3083 train_time:97215ms step_avg:84.54ms +step:1200/9000 train_loss:2.2307 train_time:101429ms step_avg:84.52ms +step:1250/9000 train_loss:2.2116 train_time:105638ms step_avg:84.51ms +step:1300/9000 train_loss:2.3084 train_time:109898ms step_avg:84.54ms +step:1350/9000 train_loss:2.2614 
train_time:114104ms step_avg:84.52ms +step:1400/9000 train_loss:2.1951 train_time:118370ms step_avg:84.55ms +step:1450/9000 train_loss:2.1930 train_time:122572ms step_avg:84.53ms +step:1500/9000 train_loss:2.1778 train_time:126773ms step_avg:84.52ms +step:1550/9000 train_loss:2.2407 train_time:131040ms step_avg:84.54ms +step:1600/9000 train_loss:2.3200 train_time:135237ms step_avg:84.52ms +step:1650/9000 train_loss:2.2860 train_time:139440ms step_avg:84.51ms +step:1700/9000 train_loss:2.1803 train_time:143703ms step_avg:84.53ms +step:1750/9000 train_loss:2.2651 train_time:147913ms step_avg:84.52ms +step:1800/9000 train_loss:2.1827 train_time:152174ms step_avg:84.54ms +step:1850/9000 train_loss:2.3896 train_time:156379ms step_avg:84.53ms +step:1900/9000 train_loss:2.0837 train_time:160585ms step_avg:84.52ms +step:1950/9000 train_loss:2.1493 train_time:164850ms step_avg:84.54ms +step:2000/9000 train_loss:2.2119 train_time:169057ms step_avg:84.53ms +step:2050/9000 train_loss:2.0638 train_time:173321ms step_avg:84.55ms +step:2100/9000 train_loss:2.1730 train_time:177522ms step_avg:84.53ms +step:2150/9000 train_loss:2.0566 train_time:181731ms step_avg:84.53ms +step:2200/9000 train_loss:2.0681 train_time:185995ms step_avg:84.54ms +step:2250/9000 train_loss:2.1739 train_time:190200ms step_avg:84.53ms +step:2300/9000 train_loss:2.2507 train_time:194458ms step_avg:84.55ms +step:2350/9000 train_loss:2.1349 train_time:198668ms step_avg:84.54ms +step:2400/9000 train_loss:2.1322 train_time:202869ms step_avg:84.53ms +step:2450/9000 train_loss:2.3134 train_time:207131ms step_avg:84.54ms +step:2500/9000 train_loss:2.0381 train_time:211420ms step_avg:84.57ms +step:2550/9000 train_loss:2.1379 train_time:215678ms step_avg:84.58ms +step:2600/9000 train_loss:2.0426 train_time:219881ms step_avg:84.57ms +step:2650/9000 train_loss:2.0869 train_time:224091ms step_avg:84.56ms +step:2700/9000 train_loss:2.1800 train_time:228350ms step_avg:84.57ms +step:2750/9000 train_loss:2.2144 
train_time:232552ms step_avg:84.56ms +step:2800/9000 train_loss:2.1536 train_time:236813ms step_avg:84.58ms +step:2850/9000 train_loss:2.1047 train_time:241014ms step_avg:84.57ms +step:2900/9000 train_loss:2.0616 train_time:245216ms step_avg:84.56ms +step:2950/9000 train_loss:1.9860 train_time:249485ms step_avg:84.57ms +step:3000/9000 train_loss:2.1461 train_time:253681ms step_avg:84.56ms +step:3050/9000 train_loss:2.0521 train_time:257884ms step_avg:84.55ms +step:3100/9000 train_loss:2.0874 train_time:262145ms step_avg:84.56ms +step:3150/9000 train_loss:2.2227 train_time:266350ms step_avg:84.56ms +step:3200/9000 train_loss:1.9742 train_time:270604ms step_avg:84.56ms +step:3250/9000 train_loss:2.1126 train_time:274812ms step_avg:84.56ms +step:3300/9000 train_loss:2.0937 train_time:279015ms step_avg:84.55ms +step:3350/9000 train_loss:2.1014 train_time:283282ms step_avg:84.56ms +step:3400/9000 train_loss:2.1058 train_time:287487ms step_avg:84.56ms +step:3450/9000 train_loss:2.1581 train_time:291747ms step_avg:84.56ms +step:3500/9000 train_loss:2.1088 train_time:295954ms step_avg:84.56ms +step:3550/9000 train_loss:2.0664 train_time:300154ms step_avg:84.55ms +step:3600/9000 train_loss:2.0948 train_time:304420ms step_avg:84.56ms +step:3650/9000 train_loss:2.0121 train_time:308623ms step_avg:84.55ms +step:3700/9000 train_loss:2.0852 train_time:312882ms step_avg:84.56ms +step:3750/9000 train_loss:2.2647 train_time:317087ms step_avg:84.56ms +step:3800/9000 train_loss:2.2267 train_time:321291ms step_avg:84.55ms +step:3850/9000 train_loss:2.0275 train_time:325548ms step_avg:84.56ms +step:3900/9000 train_loss:2.0219 train_time:329750ms step_avg:84.55ms +step:3950/9000 train_loss:2.2035 train_time:334002ms step_avg:84.56ms +step:4000/9000 train_loss:2.1666 train_time:338214ms step_avg:84.55ms +step:4050/9000 train_loss:2.0784 train_time:342410ms step_avg:84.55ms +step:4100/9000 train_loss:2.0639 train_time:346663ms step_avg:84.55ms +step:4150/9000 train_loss:2.0452 
train_time:350862ms step_avg:84.55ms +step:4200/9000 train_loss:2.0870 train_time:355119ms step_avg:84.55ms +step:4250/9000 train_loss:2.0464 train_time:359320ms step_avg:84.55ms +step:4300/9000 train_loss:2.1999 train_time:363518ms step_avg:84.54ms +step:4350/9000 train_loss:2.0216 train_time:367785ms step_avg:84.55ms +step:4400/9000 train_loss:2.1032 train_time:371986ms step_avg:84.54ms +step:4450/9000 train_loss:2.0947 train_time:376195ms step_avg:84.54ms +step:4500/9000 train_loss:1.9573 train_time:380454ms step_avg:84.55ms +step:4550/9000 train_loss:2.0305 train_time:384651ms step_avg:84.54ms +step:4600/9000 train_loss:2.0490 train_time:388901ms step_avg:84.54ms +step:4650/9000 train_loss:2.2277 train_time:393102ms step_avg:84.54ms +step:4700/9000 train_loss:2.0195 train_time:397310ms step_avg:84.53ms +step:4750/9000 train_loss:2.0288 train_time:401571ms step_avg:84.54ms +step:4800/9000 train_loss:2.0004 train_time:405773ms step_avg:84.54ms +step:4850/9000 train_loss:2.1258 train_time:410030ms step_avg:84.54ms +step:4900/9000 train_loss:2.0934 train_time:414232ms step_avg:84.54ms +step:4950/9000 train_loss:1.9670 train_time:418435ms step_avg:84.53ms +step:5000/9000 train_loss:2.0504 train_time:422698ms step_avg:84.54ms +step:5050/9000 train_loss:2.0423 train_time:426895ms step_avg:84.53ms +step:5100/9000 train_loss:2.0421 train_time:431152ms step_avg:84.54ms +step:5150/9000 train_loss:2.2369 train_time:435354ms step_avg:84.53ms +step:5200/9000 train_loss:2.0857 train_time:439544ms step_avg:84.53ms +step:5250/9000 train_loss:1.9483 train_time:443805ms step_avg:84.53ms +step:5300/9000 train_loss:2.1165 train_time:448003ms step_avg:84.53ms +step:5350/9000 train_loss:2.2033 train_time:452251ms step_avg:84.53ms +step:5400/9000 train_loss:1.9165 train_time:456454ms step_avg:84.53ms +step:5450/9000 train_loss:2.0143 train_time:460648ms step_avg:84.52ms +step:5500/9000 train_loss:1.9250 train_time:464900ms step_avg:84.53ms +step:5550/9000 train_loss:2.1402 
train_time:469102ms step_avg:84.52ms +step:5600/9000 train_loss:1.7045 train_time:473355ms step_avg:84.53ms +step:5650/9000 train_loss:2.0230 train_time:477564ms step_avg:84.52ms +step:5700/9000 train_loss:2.1459 train_time:481759ms step_avg:84.52ms +step:5750/9000 train_loss:2.0373 train_time:486017ms step_avg:84.52ms +step:5800/9000 train_loss:1.9701 train_time:490216ms step_avg:84.52ms +step:5850/9000 train_loss:2.0186 train_time:494478ms step_avg:84.53ms +step:5900/9000 train_loss:1.8671 train_time:498662ms step_avg:84.52ms +step:5950/9000 train_loss:2.0294 train_time:502855ms step_avg:84.51ms +step:6000/9000 train_loss:1.7822 train_time:507106ms step_avg:84.52ms +step:6050/9000 train_loss:1.9304 train_time:511307ms step_avg:84.51ms +step:6100/9000 train_loss:1.8549 train_time:515507ms step_avg:84.51ms +step:6150/9000 train_loss:2.1453 train_time:519760ms step_avg:84.51ms +step:6200/9000 train_loss:1.9803 train_time:523957ms step_avg:84.51ms +step:6250/9000 train_loss:2.1350 train_time:528207ms step_avg:84.51ms +step:6300/9000 train_loss:1.9505 train_time:532405ms step_avg:84.51ms +step:6350/9000 train_loss:1.9430 train_time:536601ms step_avg:84.50ms +step:6400/9000 train_loss:2.0190 train_time:540860ms step_avg:84.51ms +swa:start step:6450 +step:6450/9000 train_loss:2.0600 train_time:545054ms step_avg:84.50ms +step:6500/9000 train_loss:1.9850 train_time:549413ms step_avg:84.53ms +step:6550/9000 train_loss:2.0736 train_time:553682ms step_avg:84.53ms +late_qat:enabled step:6572 scale:0.1499 +step:6600/9000 train_loss:1.9612 train_time:557943ms step_avg:84.54ms +step:6650/9000 train_loss:2.1874 train_time:562265ms step_avg:84.55ms +step:6700/9000 train_loss:2.1141 train_time:566530ms step_avg:84.56ms +step:6750/9000 train_loss:1.9421 train_time:570850ms step_avg:84.57ms +step:6800/9000 train_loss:1.9733 train_time:575115ms step_avg:84.58ms +step:6850/9000 train_loss:2.0899 train_time:579378ms step_avg:84.58ms +step:6900/9000 train_loss:2.0469 train_time:583692ms 
step_avg:84.59ms +step:6950/9000 train_loss:1.9554 train_time:587951ms step_avg:84.60ms +step:7000/9000 train_loss:2.0264 train_time:592270ms step_avg:84.61ms +step:7050/9000 train_loss:2.0584 train_time:596536ms step_avg:84.61ms +step:7091/9000 val_loss:1.9604 val_bpb:1.0971 train_time:600086ms step_avg:84.63ms +stopping_early: wallclock_cap train_time:600086ms step:7091/9000 +peak memory allocated: 21655 MiB reserved: 22246 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9586 val_bpb:1.0961 eval_time:1981ms +Serialized model: 105994166 bytes +Code size: 120316 bytes +Serialized model int6+lzma: 15746424 bytes +Total submission size int6+lzma: 15866740 bytes +final_int6_roundtrip val_loss:1.9740 val_bpb:1.1047 eval_time:6344ms +final_int6_roundtrip_exact val_loss:1.97398934 val_bpb:1.10466967 +final_int6_sliding_window val_loss:1.9352 val_bpb:1.0830 stride:64 eval_time:74264ms +final_int6_sliding_window_exact val_loss:1.93518851 val_bpb:1.08295388 +final_int8_zlib_roundtrip_exact val_loss:1.93518851 val_bpb:1.08295388 +ttt_sliding:start chunks=1863 chunk_tokens=32768 total_windows=953440 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 +ttt_sliding:params unfrozen=26907468 frozen=4112 + ttt_chunk [1/1863] bpb=1.111535 time=0.5s + ttt_chunk [11/1863] bpb=1.060421 time=2.7s + ttt_chunk [21/1863] bpb=1.056553 time=4.9s + ttt_chunk [31/1863] bpb=1.077453 time=7.2s + ttt_chunk [41/1863] bpb=1.083045 time=9.4s + ttt_chunk [51/1863] bpb=1.081437 time=11.6s + ttt_chunk [61/1863] bpb=1.081432 time=13.8s + ttt_chunk [71/1863] bpb=1.077653 time=16.1s + ttt_chunk [81/1863] bpb=1.078662 time=18.3s + ttt_chunk [91/1863] bpb=1.081516 time=20.5s + ttt_chunk [101/1863] bpb=1.086821 time=22.8s + ttt_chunk [111/1863] bpb=1.089397 time=25.0s + ttt_chunk [121/1863] bpb=1.083096 time=27.3s + ttt_chunk [131/1863] bpb=1.080455 time=29.5s + ttt_chunk [141/1863] bpb=1.080824 time=31.7s + ttt_chunk [151/1863] bpb=1.077936 time=33.9s + ttt_chunk [161/1863] bpb=1.076300 
time=36.1s + ttt_chunk [171/1863] bpb=1.082439 time=38.4s + ttt_chunk [181/1863] bpb=1.082506 time=40.6s + ttt_chunk [191/1863] bpb=1.083691 time=42.8s + ttt_chunk [201/1863] bpb=1.084941 time=45.1s + ttt_chunk [211/1863] bpb=1.085687 time=47.3s + ttt_chunk [221/1863] bpb=1.084827 time=49.5s + ttt_chunk [231/1863] bpb=1.084126 time=51.7s + ttt_chunk [241/1863] bpb=1.083011 time=54.0s + ttt_chunk [251/1863] bpb=1.082294 time=56.2s + ttt_chunk [261/1863] bpb=1.081513 time=58.4s + ttt_chunk [271/1863] bpb=1.084547 time=60.6s + ttt_chunk [281/1863] bpb=1.085218 time=62.8s + ttt_chunk [291/1863] bpb=1.087155 time=65.1s + ttt_chunk [301/1863] bpb=1.086040 time=67.3s + ttt_chunk [311/1863] bpb=1.087012 time=69.5s + ttt_chunk [321/1863] bpb=1.086849 time=71.8s + ttt_chunk [331/1863] bpb=1.085391 time=74.0s + ttt_chunk [341/1863] bpb=1.084786 time=76.2s + ttt_chunk [351/1863] bpb=1.085950 time=78.4s + ttt_chunk [361/1863] bpb=1.084874 time=80.6s + ttt_chunk [371/1863] bpb=1.085177 time=82.8s + ttt_chunk [381/1863] bpb=1.086644 time=85.0s + ttt_chunk [391/1863] bpb=1.087355 time=87.2s + ttt_chunk [401/1863] bpb=1.085333 time=89.5s + ttt_chunk [411/1863] bpb=1.084549 time=91.7s + ttt_chunk [421/1863] bpb=1.085512 time=93.9s + ttt_chunk [431/1863] bpb=1.085245 time=96.1s + ttt_chunk [441/1863] bpb=1.085611 time=98.3s + ttt_chunk [451/1863] bpb=1.084905 time=100.5s + ttt_chunk [461/1863] bpb=1.083833 time=102.7s + ttt_chunk [471/1863] bpb=1.084366 time=104.9s + ttt_chunk [481/1863] bpb=1.085126 time=107.2s + ttt_chunk [491/1863] bpb=1.086533 time=109.4s + ttt_chunk [501/1863] bpb=1.085669 time=111.6s + ttt_chunk [511/1863] bpb=1.086224 time=113.8s + ttt_chunk [521/1863] bpb=1.087202 time=116.0s + ttt_chunk [531/1863] bpb=1.087793 time=118.3s + ttt_chunk [541/1863] bpb=1.086872 time=120.5s + ttt_chunk [551/1863] bpb=1.087110 time=122.7s + ttt_chunk [561/1863] bpb=1.086900 time=124.9s + ttt_chunk [571/1863] bpb=1.086829 time=127.1s + ttt_chunk [581/1863] bpb=1.086629 time=129.4s 
+ ttt_chunk [591/1863] bpb=1.086281 time=131.6s + ttt_chunk [601/1863] bpb=1.086420 time=133.8s + ttt_chunk [611/1863] bpb=1.086167 time=136.0s + ttt_chunk [621/1863] bpb=1.086114 time=138.2s + ttt_chunk [631/1863] bpb=1.085612 time=140.4s + ttt_chunk [641/1863] bpb=1.084040 time=142.7s + ttt_chunk [651/1863] bpb=1.084179 time=144.9s + ttt_chunk [661/1863] bpb=1.083296 time=147.1s + ttt_chunk [671/1863] bpb=1.082736 time=149.3s + ttt_chunk [681/1863] bpb=1.082779 time=151.5s + ttt_chunk [691/1863] bpb=1.082318 time=153.8s + ttt_chunk [701/1863] bpb=1.082138 time=155.9s + ttt_chunk [711/1863] bpb=1.081640 time=158.2s + ttt_chunk [721/1863] bpb=1.081762 time=160.4s + ttt_chunk [731/1863] bpb=1.080957 time=162.6s + ttt_chunk [741/1863] bpb=1.080612 time=164.8s + ttt_chunk [751/1863] bpb=1.080924 time=167.0s + ttt_chunk [761/1863] bpb=1.080491 time=169.2s + ttt_chunk [771/1863] bpb=1.080966 time=171.5s + ttt_chunk [781/1863] bpb=1.080278 time=173.7s + ttt_chunk [791/1863] bpb=1.080791 time=175.9s + ttt_chunk [801/1863] bpb=1.080295 time=178.1s + ttt_chunk [811/1863] bpb=1.080389 time=180.4s + ttt_chunk [821/1863] bpb=1.079860 time=182.6s + ttt_chunk [831/1863] bpb=1.079966 time=184.8s + ttt_chunk [841/1863] bpb=1.079722 time=187.0s + ttt_chunk [851/1863] bpb=1.079317 time=189.2s + ttt_chunk [861/1863] bpb=1.079266 time=191.4s + ttt_chunk [871/1863] bpb=1.079078 time=193.6s + ttt_chunk [881/1863] bpb=1.078547 time=195.8s + ttt_chunk [891/1863] bpb=1.078572 time=198.0s + ttt_chunk [901/1863] bpb=1.079044 time=200.3s + ttt_chunk [911/1863] bpb=1.078846 time=202.5s + ttt_chunk [921/1863] bpb=1.079089 time=204.7s + ttt_chunk [931/1863] bpb=1.079517 time=207.0s + ttt_chunk [941/1863] bpb=1.080066 time=209.2s + ttt_chunk [951/1863] bpb=1.080078 time=211.4s + ttt_chunk [961/1863] bpb=1.080248 time=213.6s + ttt_chunk [971/1863] bpb=1.080552 time=215.8s + ttt_chunk [981/1863] bpb=1.080081 time=218.1s + ttt_chunk [991/1863] bpb=1.080136 time=220.3s + ttt_chunk [1001/1863] 
bpb=1.080748 time=222.5s + ttt_chunk [1011/1863] bpb=1.080823 time=224.7s + ttt_chunk [1021/1863] bpb=1.080891 time=226.9s + ttt_chunk [1031/1863] bpb=1.081532 time=229.1s + ttt_chunk [1041/1863] bpb=1.081470 time=231.3s + ttt_chunk [1051/1863] bpb=1.081764 time=233.5s + ttt_chunk [1061/1863] bpb=1.081624 time=235.8s + ttt_chunk [1071/1863] bpb=1.082156 time=238.0s + ttt_chunk [1081/1863] bpb=1.082102 time=240.2s + ttt_chunk [1091/1863] bpb=1.081623 time=242.4s + ttt_chunk [1101/1863] bpb=1.081874 time=244.6s + ttt_chunk [1111/1863] bpb=1.081324 time=246.8s + ttt_chunk [1121/1863] bpb=1.081017 time=249.1s + ttt_chunk [1131/1863] bpb=1.081137 time=251.3s + ttt_chunk [1141/1863] bpb=1.081071 time=253.5s + ttt_chunk [1151/1863] bpb=1.080153 time=255.7s + ttt_chunk [1161/1863] bpb=1.080303 time=257.9s + ttt_chunk [1171/1863] bpb=1.080546 time=260.1s + ttt_chunk [1181/1863] bpb=1.081014 time=262.3s + ttt_chunk [1191/1863] bpb=1.080389 time=264.5s + ttt_chunk [1201/1863] bpb=1.080348 time=266.8s + ttt_chunk [1211/1863] bpb=1.080721 time=269.0s + ttt_chunk [1221/1863] bpb=1.080457 time=271.2s + ttt_chunk [1231/1863] bpb=1.079809 time=273.4s + ttt_chunk [1241/1863] bpb=1.079099 time=275.6s + ttt_chunk [1251/1863] bpb=1.078368 time=277.8s + ttt_chunk [1261/1863] bpb=1.078573 time=280.1s + ttt_chunk [1271/1863] bpb=1.078134 time=282.3s + ttt_chunk [1281/1863] bpb=1.078435 time=284.5s + ttt_chunk [1291/1863] bpb=1.078283 time=286.7s + ttt_chunk [1301/1863] bpb=1.077960 time=288.9s + ttt_chunk [1311/1863] bpb=1.078218 time=291.1s + ttt_chunk [1321/1863] bpb=1.078045 time=293.3s + ttt_chunk [1331/1863] bpb=1.078014 time=295.5s + ttt_chunk [1341/1863] bpb=1.078345 time=297.7s + ttt_chunk [1351/1863] bpb=1.078784 time=300.0s + ttt_chunk [1361/1863] bpb=1.079050 time=302.2s + ttt_chunk [1371/1863] bpb=1.078780 time=304.4s + ttt_chunk [1381/1863] bpb=1.078668 time=306.6s + ttt_chunk [1391/1863] bpb=1.078862 time=308.9s + ttt_chunk [1401/1863] bpb=1.079014 time=311.1s + ttt_chunk 
[1411/1863] bpb=1.079019 time=313.3s + ttt_chunk [1421/1863] bpb=1.079206 time=315.5s + ttt_chunk [1431/1863] bpb=1.079617 time=317.8s + ttt_chunk [1441/1863] bpb=1.079342 time=320.0s + ttt_chunk [1451/1863] bpb=1.079629 time=322.2s + ttt_chunk [1461/1863] bpb=1.079691 time=324.4s + ttt_chunk [1471/1863] bpb=1.079826 time=326.7s + ttt_chunk [1481/1863] bpb=1.079518 time=328.9s + ttt_chunk [1491/1863] bpb=1.079614 time=331.1s + ttt_chunk [1501/1863] bpb=1.079844 time=333.3s + ttt_chunk [1511/1863] bpb=1.079443 time=335.5s + ttt_chunk [1521/1863] bpb=1.079595 time=337.8s + ttt_chunk [1531/1863] bpb=1.079910 time=340.0s + ttt_chunk [1541/1863] bpb=1.079908 time=342.2s + ttt_chunk [1551/1863] bpb=1.079991 time=344.4s + ttt_chunk [1561/1863] bpb=1.080059 time=346.6s + ttt_chunk [1571/1863] bpb=1.080410 time=348.8s + ttt_chunk [1581/1863] bpb=1.080430 time=351.1s + ttt_chunk [1591/1863] bpb=1.080828 time=353.3s + ttt_chunk [1601/1863] bpb=1.080819 time=355.5s + ttt_chunk [1611/1863] bpb=1.080553 time=357.7s + ttt_chunk [1621/1863] bpb=1.080586 time=359.9s + ttt_chunk [1631/1863] bpb=1.081083 time=362.2s + ttt_chunk [1641/1863] bpb=1.081288 time=364.4s + ttt_chunk [1651/1863] bpb=1.081338 time=366.6s + ttt_chunk [1661/1863] bpb=1.081156 time=368.8s + ttt_chunk [1671/1863] bpb=1.081276 time=371.0s + ttt_chunk [1681/1863] bpb=1.081227 time=373.3s + ttt_chunk [1691/1863] bpb=1.081085 time=375.5s + ttt_chunk [1701/1863] bpb=1.081372 time=377.7s + ttt_chunk [1711/1863] bpb=1.080812 time=379.9s + ttt_chunk [1721/1863] bpb=1.080565 time=382.1s + ttt_chunk [1731/1863] bpb=1.080668 time=384.3s + ttt_chunk [1741/1863] bpb=1.080630 time=386.5s + ttt_chunk [1751/1863] bpb=1.080649 time=388.8s + ttt_chunk [1761/1863] bpb=1.080200 time=391.0s + ttt_chunk [1771/1863] bpb=1.080081 time=393.2s + ttt_chunk [1781/1863] bpb=1.079896 time=395.4s + ttt_chunk [1791/1863] bpb=1.079885 time=397.6s + ttt_chunk [1801/1863] bpb=1.079750 time=399.8s + ttt_chunk [1811/1863] bpb=1.079612 time=402.0s + 
ttt_chunk [1821/1863] bpb=1.079375 time=404.3s + ttt_chunk [1831/1863] bpb=1.078965 time=406.5s + ttt_chunk [1841/1863] bpb=1.078795 time=408.7s + ttt_chunk [1851/1863] bpb=1.078571 time=410.9s + ttt_chunk [1861/1863] bpb=1.078602 time=413.2s + ttt_chunk [1863/1863] bpb=1.078556 time=413.4s +ttt_sliding:done val_loss=1.930065 val_bpb=1.080087 elapsed=413.4s +legal_ttt val_loss:1.9301 val_bpb:1.0801 eval_time:413954ms +legal_ttt_exact val_loss:1.93006483 val_bpb:1.08008661