diff --git a/.gitignore b/.gitignore index 3423c416a7..5180fff54b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,32 @@ -data/tokenizers +# Auto-generated outputs +runs/ +logs/ +sweep_logs/ +checkpoints/ +submission/ + +# Python __pycache__/ +.pytest_cache/ +*.pyc +*.pyo +*.egg-info/ + +# Data (large binary files) +data/datasets/ +data/tokenizers/ + +# Model artifacts +*.pt +*.ptz +*.ckpt + +# OS .DS_Store -modded-nanogpt/ -modded-nanogpt -data/datasets -data/manifest.json -data/docs_selected.jsonl -.mypy_cache/ -.venv -logs/ \ No newline at end of file +Thumbs.db + +# IDE +.vscode/ +.idea/ +*.swp +*.swo diff --git a/h100-runs.sh b/h100-runs.sh new file mode 100644 index 0000000000..1c700ecf43 --- /dev/null +++ b/h100-runs.sh @@ -0,0 +1,133 @@ +#!/bin/bash +# H100 Runs — Based on 32 local experiments (2026-03-27) +# +# Key findings from local queue: +# 1. Seq_len 2048 is a MASSIVE win (-0.4 BPB vs 1024) +# 2. Muon LR 0.025 gives slight edge over default 0.04 +# 3. TTT gives -0.04 to -0.11 BPB on H100 (eval is free/unlimited time) +# 4. GQA (4KV) keeps artifact under 16MB with ~0.01 BPB cost +# 5. No RoPE slightly better than partial RoPE +# 6. 13L/576d >> 11L/512d but artifact borderline +# +# Previous best: 1.2642 BPB (13L/576d, seq1024, TTT=7, 15.7MB artifact) +# +# Budget: ~$36 for 3 runs (~$12 each) + +set -e +cd /workspace/parameter-golf + +export HF_HOME=/workspace/.hf_cache +export DATA_PATH=./data/datasets/fineweb10B_sp1024 +export TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model + +# ============================================================ +# RUN A: New Best Candidate — 13L/576d + Seq2048 + TTT=7 +# ============================================================ +# Combines our best architecture (13L/576d) with the biggest +# local discovery (seq_len 2048). Previous 13L/576d got 1.264 +# with seq1024 — seq2048 should push significantly lower. +# Risk: artifact might exceed 16MB (previous was 15.7MB) +echo "=== RUN A: 13L/576d + Seq2048 + TTT=7 ===" +RUN_ID="runa_13L576d_seq2048" \ +NUM_LAYERS=13 MODEL_DIM=576 NUM_HEADS=9 NUM_KV_HEADS=3 \ +TRAIN_SEQ_LEN=2048 \ +ITERATIONS=12000 WARMDOWN_ITERS=3500 WARMUP_STEPS=20 \ +MATRIX_LR=0.025 \ +BIGRAMHASH_BUCKETS=4096 \ +SMEARGATE=1 UNET_SKIPS=1 INT6_QAT=1 TIE_EMBEDDINGS=1 \ +ROPE_PARTIAL_DIMS=0 LN_SCALE=1 XSA_LAYERS=4 \ +EMA_DECAY=0.997 \ +TTT_ENABLED=1 TTT_EPOCHS=7 TTT_LR=0.002 TTT_CHUNK_TOKENS=32768 TTT_BATCH_SEQS=32 \ +EVAL_STRIDE=64 \ +MAX_WALLCLOCK_SECONDS=600 \ +python train_gpt.py 2>&1 | tee /workspace/runa_13L576d_seq2048.log + +# Save artifact with descriptive name +cp final_model.int6.ptz /workspace/artifact_runA_13L576d_seq2048.ptz 2>/dev/null || true +echo "Run A complete. Check log for BPB." +echo "" + +# ============================================================ +# RUN B: Safe Candidate — 11L/512d + Seq2048 + TTT=7 +# ============================================================ +# Same seq2048 win but with guaranteed-legal 11L/512d config. +# Previous 11L/512d got 1.352 (seq1024) — seq2048 should be ~1.25-1.30? 
+# Artifact: ~14-15MB (safely under 16MB limit) +echo "=== RUN B: 11L/512d + Seq2048 + TTT=7 ===" +RUN_ID="runb_11L512d_seq2048" \ +NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 \ +TRAIN_SEQ_LEN=2048 \ +ITERATIONS=12000 WARMDOWN_ITERS=3500 WARMUP_STEPS=20 \ +MATRIX_LR=0.025 \ +BIGRAMHASH_BUCKETS=4096 \ +SMEARGATE=1 UNET_SKIPS=1 INT6_QAT=1 TIE_EMBEDDINGS=1 \ +ROPE_PARTIAL_DIMS=0 LN_SCALE=1 XSA_LAYERS=4 \ +EMA_DECAY=0.997 \ +TTT_ENABLED=1 TTT_EPOCHS=7 TTT_LR=0.002 TTT_CHUNK_TOKENS=32768 TTT_BATCH_SEQS=32 \ +EVAL_STRIDE=64 \ +MAX_WALLCLOCK_SECONDS=600 \ +python train_gpt.py 2>&1 | tee /workspace/runb_11L512d_seq2048.log + +cp final_model.int6.ptz /workspace/artifact_runB_11L512d_seq2048.ptz 2>/dev/null || true +echo "Run B complete. Check log for BPB." +echo "" + +# ============================================================ +# RUN C: Previous Best Rerun (baseline comparison) +# ============================================================ +# Exact config from Run 4 that got 1.2642 BPB (seq1024). +# Use as baseline to measure how much seq2048 actually helps on H100. +echo "=== RUN C: 13L/576d + Seq1024 + TTT=7 (baseline) ===" +RUN_ID="runc_13L576d_seq1024_baseline" \ +NUM_LAYERS=13 MODEL_DIM=576 NUM_HEADS=9 NUM_KV_HEADS=3 \ +TRAIN_SEQ_LEN=1024 \ +ITERATIONS=12000 WARMDOWN_ITERS=3500 WARMUP_STEPS=20 \ +MATRIX_LR=0.025 \ +BIGRAMHASH_BUCKETS=4096 \ +SMEARGATE=1 UNET_SKIPS=1 INT6_QAT=1 TIE_EMBEDDINGS=1 \ +ROPE_PARTIAL_DIMS=0 LN_SCALE=1 XSA_LAYERS=4 \ +EMA_DECAY=0.997 \ +TTT_ENABLED=1 TTT_EPOCHS=7 TTT_LR=0.002 TTT_CHUNK_TOKENS=32768 TTT_BATCH_SEQS=32 \ +EVAL_STRIDE=64 \ +MAX_WALLCLOCK_SECONDS=600 \ +python train_gpt.py 2>&1 | tee /workspace/runc_13L576d_seq1024.log + +cp final_model.int6.ptz /workspace/artifact_runC_13L576d_seq1024.ptz 2>/dev/null || true +echo "Run C complete. Check log for BPB." +echo "" + +# ============================================================ +# RUN D: Best Local Config (MHA) + Seq2048 + TTT=7 +# ============================================================ +# The best local config (BPB 1.628) used full MHA (8 KV heads) +# but was NEVER tested on H100. All H100 runs used GQA. +# Combining MHA with seq2048 could be the best overall. +# Risk: artifact may exceed 16MB (was 19.4MB locally at 5000 steps) +# but H100 trains longer + we can try compression locally. +echo "=== RUN D: 11L/512d MHA + Seq2048 + TTT=7 (best local config) ===" +RUN_ID="rund_11L512d_mha_seq2048" \ +NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=8 \ +TRAIN_SEQ_LEN=2048 \ +ITERATIONS=12000 WARMDOWN_ITERS=3500 WARMUP_STEPS=20 \ +MATRIX_LR=0.025 \ +BIGRAMHASH_BUCKETS=4096 \ +SMEARGATE=1 UNET_SKIPS=1 INT6_QAT=1 TIE_EMBEDDINGS=1 \ +ROPE_PARTIAL_DIMS=0 LN_SCALE=1 XSA_LAYERS=4 \ +EMA_DECAY=0.997 \ +TTT_ENABLED=1 TTT_EPOCHS=7 TTT_LR=0.002 TTT_CHUNK_TOKENS=32768 TTT_BATCH_SEQS=32 \ +EVAL_STRIDE=64 \ +MAX_WALLCLOCK_SECONDS=600 \ +python train_gpt.py 2>&1 | tee /workspace/rund_11L512d_mha_seq2048.log + +cp final_model.int6.ptz /workspace/artifact_runD_11L512d_mha_seq2048.ptz 2>/dev/null || true +echo "Run D complete. Check log for BPB." 
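+
+# Optional artifact-size check (a minimal sketch, not part of the original runs):
+# Runs A and D risk exceeding the 16MB artifact limit, so flag any copied artifact
+# that is too large. Assumes the limit means 16 MiB of the compressed .ptz files
+# copied above, and that GNU/BSD `stat` is available on the pod.
+for artifact in /workspace/artifact_run*.ptz; do
+  [ -f "$artifact" ] || continue
+  size_bytes=$(stat -c%s "$artifact" 2>/dev/null || stat -f%z "$artifact")
+  if [ "${size_bytes:-0}" -gt $((16 * 1024 * 1024)) ]; then
+    echo "WARNING: $artifact is ${size_bytes} bytes (over the assumed 16 MiB limit)"
+  fi
+done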
+echo "" + +echo "=== ALL RUNS COMPLETE ===" +echo "Check artifacts in /workspace/artifact_run*.ptz" +echo "Logs in /workspace/run*.log" +echo "" +echo "Expected results:" +echo " Run A (13L+seq2048): Target <1.22 BPB, artifact ~15-17MB" +echo " Run B (11L+seq2048): Target <1.30 BPB, artifact ~14-15MB (safe)" +echo " Run C (13L+seq1024): Expect ~1.264 BPB (baseline for comparison)" diff --git a/ngram_eval.py b/ngram_eval.py new file mode 100644 index 0000000000..82a67849a5 --- /dev/null +++ b/ngram_eval.py @@ -0,0 +1,249 @@ +""" +N-gram Backoff Eval Cache for Parameter Golf. +Implements the breakthrough eval-time technique from PR #809 (0.295 BPB). + +During evaluation, builds a growing N-gram cache from already-scored tokens. +Uses highest-order match with entropy-adaptive alpha blending to combine +N-gram predictions with model predictions. + +Usage: + Set NGRAM_EVAL=1 to enable during final evaluation. + Set NGRAM_MAX_ORDER=9 for max N-gram order (default 9). +""" +from __future__ import annotations + +import math +import os +import time +from collections import defaultdict + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + + +class BackoffNgramMixer: + """Backoff N-gram cache with entropy-adaptive blending. + + Maintains counts of N-gram patterns from scored tokens. + For predictions, finds the highest-order matching context + and builds a probability distribution from observed next-token counts. + Blends with model probabilities using entropy-adaptive alpha. + """ + + def __init__( + self, + vocab_size: int, + max_order: int = 9, + alpha_min: float = 0.05, + alpha_max: float = 0.70, + entropy_center: float = 3.0, + entropy_scale: float = 1.5, + ): + self.vocab_size = vocab_size + self.max_order = max_order + self.alpha_min = alpha_min + self.alpha_max = alpha_max + self.entropy_center = entropy_center + self.entropy_scale = entropy_scale + # N-gram counts: order -> {context_tuple -> {next_token_id -> count}} + self.tables: dict[int, dict[tuple[int, ...], dict[int, int]]] = { + order: defaultdict(lambda: defaultdict(int)) + for order in range(2, max_order + 1) + } + self.history: list[int] = [] + + def update(self, token_id: int) -> None: + """Add a scored token to the history and update all N-gram tables.""" + self.history.append(token_id) + n = len(self.history) + for order in range(2, min(self.max_order + 1, n + 1)): + context = tuple(self.history[n - order : n - 1]) + self.tables[order][context][token_id] += 1 + + def update_batch(self, token_ids: list[int]) -> None: + """Batch update the cache with multiple tokens.""" + for t in token_ids: + self.update(t) + + def get_ngram_dist(self, context: list[int], device: torch.device) -> Tensor | None: + """Get N-gram probability distribution using backoff strategy. + + Tries highest order first, backs off to lower orders. + Returns tensor of shape [vocab_size] on device, or None if no match. + """ + for order in range(self.max_order, 1, -1): + if len(context) < order - 1: + continue + ctx_key = tuple(context[-(order - 1):]) + counts = self.tables[order].get(ctx_key) + if counts is not None and len(counts) > 0: + total = sum(counts.values()) + if total < 1: + continue + probs = torch.zeros(self.vocab_size, device=device, dtype=torch.float32) + for tok, cnt in counts.items(): + probs[tok] = cnt / total + return probs + return None + + def compute_alpha(self, model_logprobs: Tensor) -> float: + """Compute entropy-adaptive blending alpha from model log-probabilities. 
+ + High model entropy -> trust N-gram more (higher alpha). + Low model entropy -> trust model more (lower alpha). + """ + probs = model_logprobs.exp() + entropy = -(probs * model_logprobs).sum().item() + # Clamp for numerical safety + entropy = max(0.0, entropy) + x = (entropy - self.entropy_center) * self.entropy_scale + # Sigmoid + if x > 20: + sigmoid = 1.0 + elif x < -20: + sigmoid = 0.0 + else: + sigmoid = 1.0 / (1.0 + math.exp(-x)) + return self.alpha_min + (self.alpha_max - self.alpha_min) * sigmoid + + +def eval_val_ngram( + args, + base_model, + model, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + max_order: int = 9, + log_fn=None, +) -> tuple[float, float]: + """N-gram enhanced validation evaluation. + + Processes validation tokens in chunks. For each chunk: + 1. Get model logits via forward pass + 2. For each token, look up N-gram prediction from cache + 3. Blend model + N-gram probabilities with entropy-adaptive alpha + 4. Compute cross-entropy from blended distribution + 5. Add scored tokens to cache for future predictions + + Returns (val_loss, val_bpb) like the standard eval_val. + """ + if log_fn is None: + log_fn = print + + mixer = BackoffNgramMixer( + vocab_size=args.vocab_size, + max_order=max_order, + ) + + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + # We need model to return logits instead of loss + base_model._return_logits = True + + loss_sum = 0.0 + token_count = 0.0 + byte_count = 0.0 + ngram_hits = 0 + ngram_misses = 0 + + model.eval() + t_start = time.perf_counter() + chunks_done = 0 + + with torch.inference_mode(): + for start in range(0, total_tokens - seq_len + 1, seq_len): + end = start + seq_len + 1 + if end > val_tokens.numel(): + break + + chunk = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = chunk[:-1].unsqueeze(0) # [1, seq_len] + y = chunk[1:] # [seq_len] + + # Get model logits + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = model(x, y.unsqueeze(0)) # [1, seq_len, vocab_size] + + logits = logits.squeeze(0).float() # [seq_len, vocab_size] + model_logprobs = F.log_softmax(logits, dim=-1) + + # Context tokens for N-gram lookup (from history + current chunk) + chunk_tokens = chunk.cpu().tolist() + + # Process each token in the chunk + chunk_loss = 0.0 + for t in range(seq_len): + target = y[t].item() + token_logprobs = model_logprobs[t] # [vocab_size] + + # Build context from history + current chunk tokens up to position t + context = list(mixer.history) + chunk_tokens[:t + 1] + + # Try N-gram prediction + ngram_probs = mixer.get_ngram_dist(context, device) + + if ngram_probs is not None: + ngram_hits += 1 + alpha = mixer.compute_alpha(token_logprobs) + # Blend: p_final = (1-alpha) * p_model + alpha * p_ngram + model_probs = token_logprobs.exp() + blended = (1.0 - alpha) * model_probs + alpha * ngram_probs + blended = blended.clamp(min=1e-10) + token_loss = -torch.log(blended[target]).item() + else: + ngram_misses += 1 + token_loss = -token_logprobs[target].item() + + chunk_loss += token_loss + + loss_sum += chunk_loss + token_count += seq_len + + # Byte counting + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + t_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + t_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += 
t_bytes.to(torch.float64).sum().item() + + # Add all scored tokens to the N-gram cache + mixer.update_batch(y.cpu().tolist()) + + chunks_done += 1 + if chunks_done % 50 == 0: + elapsed = time.perf_counter() - t_start + current_bpb = (loss_sum / token_count) / math.log(2.0) * (token_count / max(byte_count, 1)) + hit_rate = ngram_hits / max(ngram_hits + ngram_misses, 1) * 100 + log_fn( + f"ngram_eval: chunk {chunks_done}, " + f"tokens {int(token_count)}/{total_tokens}, " + f"bpb_so_far {current_bpb:.4f}, " + f"hit_rate {hit_rate:.1f}%, " + f"elapsed {elapsed:.0f}s" + ) + + # Restore normal forward mode + base_model._return_logits = False + + val_loss = loss_sum / token_count + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count / byte_count + + elapsed = time.perf_counter() - t_start + hit_rate = ngram_hits / max(ngram_hits + ngram_misses, 1) * 100 + log_fn( + f"ngram_eval: DONE, {chunks_done} chunks, {elapsed:.0f}s, " + f"hit_rate {hit_rate:.1f}%, cache_size {len(mixer.history)}" + ) + + model.train() + return float(val_loss), float(bits_per_token * tokens_per_byte) diff --git a/scripts/h100-8x-final-submission.sh b/scripts/h100-8x-final-submission.sh new file mode 100644 index 0000000000..c41e53e203 --- /dev/null +++ b/scripts/h100-8x-final-submission.sh @@ -0,0 +1,142 @@ +#!/bin/bash +# ============================================================================= +# 8×H100 FINAL SUBMISSION SCRIPT +# 3-seed validation run for leaderboard submission +# Updated: 2026-04-05 (uses train_gpt_full_stack.py) +# ============================================================================= +# WHEN TO RUN: Only when 1×H100 ablations show our best BPB and fits 16MB. +# COST: ~$21/hr × ~1.5hr total = ~$31-40 +# ============================================================================= +# ⚠️ NO TTT EVAL unless this is explicitly a submission run. +# ⚠️ NEVER run parallel — all 8 GPUs needed per run. +# ============================================================================= +# Usage: bash h100-8x-final-submission.sh [OPTIONAL: config name label] +# Seeds: 4, 30, 2026 +# ============================================================================= + +set -e +cd /workspace/parameter-golf + +LABEL="${1:-sp4096_full_stack}" +LOG_DIR="logs" +mkdir -p "$LOG_DIR" + +TRAIN="train_gpt_full_stack.py" +[ -f "$TRAIN" ] || { echo "ERROR: $TRAIN not found! Run git pull first."; exit 1; } + +echo "========================================" +echo "8×H100 Final Submission Run" +echo "Config: $LABEL" +echo "Script: $TRAIN" +echo "Started: $(date)" +echo "========================================" +echo "" + +# ============================================================================= +# ⚙️ CONFIGURE BEST CONFIG HERE (update after 1×H100 ablations confirm winner) +# ============================================================================= +# SP4096 + MLP4x + MuonEq-R + depth recur + parallel resid + disc TTT + GPTQ +# Adjust based on actual ablation results. 
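+# Note: BEST_CONFIG is expanded unquoted as `env $BEST_CONFIG ...` below, so it relies
+# on shell word splitting; keep every KEY=VALUE entry free of spaces and quotes.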
+# ============================================================================= +BEST_CONFIG=" + VOCAB_SIZE=4096 + DATA_PATH=./data/datasets/fineweb10B_sp4096 + TOKENIZER_PATH=./data/tokenizers/fineweb_4096_bpe.model + TRAIN_SEQ_LEN=4096 + MLP_MULT=4 + MUON_MOMENTUM=0.95 + MUONEQ_R=1 + WARMDOWN_SCHEDULE=sqrt + QK_GAIN_INIT=5.0 + BIGRAMHASH_DIM=3072 + RECUR_LAYERS=3,4,5 + PARALLEL_START_LAYER=4 + TTT_ENABLED=1 + TTT_PREQUANT=1 + TTT_DISCRIMINATIVE=1 + TTT_EPOCHS=10 + TTT_LR=0.005 + GPTQ_FULL_HESSIAN=1 + GPTQ_DAMP=0.005 + POLAR_EXPRESS=1 + MAX_WALLCLOCK_SECONDS=590 + GRAD_ACCUM_STEPS=1 +" +# Note: GRAD_ACCUM_STEPS=1 on 8×H100 (vs 8 on 1×H100 — torchrun handles total batch) + +# ============================================================================= +# SEED 4 +# ============================================================================= +echo "[1/3] Running seed 4..." +env $BEST_CONFIG SEED=4 \ + torchrun --nproc_per_node=8 --standalone "$TRAIN" \ + 2>&1 | tee "$LOG_DIR/8x_${LABEL}_seed4.log" +echo "" +echo "Seed 4 complete. BPB:" +grep -E "ttt_bpb|slot_bpb|roundtrip.*bpb|val_bpb" "$LOG_DIR/8x_${LABEL}_seed4.log" | tail -5 +echo "" + +# ============================================================================= +# SEED 30 +# ============================================================================= +echo "[2/3] Running seed 30..." +env $BEST_CONFIG SEED=30 \ + torchrun --nproc_per_node=8 --standalone "$TRAIN" \ + 2>&1 | tee "$LOG_DIR/8x_${LABEL}_seed30.log" +echo "" +echo "Seed 30 complete. BPB:" +grep -E "ttt_bpb|slot_bpb|roundtrip.*bpb|val_bpb" "$LOG_DIR/8x_${LABEL}_seed30.log" | tail -5 +echo "" + +# ============================================================================= +# SEED 2026 +# ============================================================================= +echo "[3/3] Running seed 2026..." +env $BEST_CONFIG SEED=2026 \ + torchrun --nproc_per_node=8 --standalone "$TRAIN" \ + 2>&1 | tee "$LOG_DIR/8x_${LABEL}_seed2026.log" +echo "" +echo "Seed 2026 complete. 
BPB:" +grep -E "ttt_bpb|slot_bpb|roundtrip.*bpb|val_bpb" "$LOG_DIR/8x_${LABEL}_seed2026.log" | tail -5 +echo "" + +# ============================================================================= +# RESULTS SUMMARY +# ============================================================================= +echo "========================================" +echo "3-SEED VALIDATION COMPLETE" +echo "Config: $LABEL" +echo "========================================" +echo "" + +# Extract BPB from each seed +extract_bpb() { + local logfile="$1" + grep -E "ttt_bpb|discriminative.*bpb|roundtrip.*bpb" "$logfile" 2>/dev/null | \ + grep -oP '[0-9]\.[0-9]{4,}' | tail -1 +} + +S1=$(extract_bpb "$LOG_DIR/8x_${LABEL}_seed4.log") +S2=$(extract_bpb "$LOG_DIR/8x_${LABEL}_seed30.log") +S3=$(extract_bpb "$LOG_DIR/8x_${LABEL}_seed2026.log") + +echo "Seed 4: ${S1:-???}" +echo "Seed 30: ${S2:-???}" +echo "Seed 2026: ${S3:-???}" + +if command -v python3 &>/dev/null && [ -n "$S1" ] && [ -n "$S2" ] && [ -n "$S3" ]; then + python3 -c " +s = [$S1, $S2, $S3] +mean = sum(s)/3 +std = (sum((x-mean)**2 for x in s)/3)**0.5 +print(f'3-seed mean: {mean:.4f} BPB') +print(f'3-seed std: {std:.4f}') +print(f'Individual: {s}') +" +fi + +echo "" +echo "Artifact sizes:" +ls -lh artifacts/*.lzma artifacts/*.br 2>/dev/null | tail -5 || echo "(no artifacts found in artifacts/)" +echo "" +echo "Logs: $LOG_DIR/8x_${LABEL}_seed*.log" diff --git a/scripts/h100-next-pod-setup.sh b/scripts/h100-next-pod-setup.sh new file mode 100644 index 0000000000..c8d44265db --- /dev/null +++ b/scripts/h100-next-pod-setup.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# ============================================================================= +# NEXT POD SETUP SCRIPT +# Phase A: SP4096 + MLP4x + full architecture stack +# Updated: 2026-04-05 +# ============================================================================= +# Usage: bash h100-next-pod-setup.sh +# This sets up the environment, downloads SP4096 data, then launches runs. +# ============================================================================= + +set -e +cd /workspace + +# ---- 0. Benchmark this pod first ---- +echo "=== POD BENCHMARK ===" +curl -s https://raw.githubusercontent.com/NathanMaine/runpod-gpu-benchmark/main/pod-test.sh | bash || true +echo "===========================" +echo "If GEMM > 0.70ms or MemBW < 2000 GB/s → SWITCH PODS before continuing!" +read -p "Pod OK? (y/n): " pod_ok +if [[ "$pod_ok" != "y" ]]; then + echo "Aborting — switch pods first." + exit 1 +fi + +# ---- 1. Clone repo ---- +echo "=== Cloning repo ===" +if [ ! -d "parameter-golf" ]; then + git clone https://github.com/Programmerryoki/parameter-golf.git +fi +cd parameter-golf + +# ---- 2. Install deps ---- +echo "=== Installing dependencies ===" +pip install -q brotli zstandard tiktoken 2>&1 | tail -5 + +# ---- 3. Download SP4096 data ---- +echo "=== Downloading SP4096 dataset ===" +mkdir -p data/datasets +if [ ! -d "data/datasets/fineweb10B_sp4096" ] || [ "$(ls data/datasets/fineweb10B_sp4096/*.bin 2>/dev/null | wc -l)" -lt 80 ]; then + MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf \ + python3 data/cached_challenge_fineweb.py --variant sp4096 --skip-manifest + echo "SP4096 data downloaded." +else + echo "SP4096 data already present ($(ls data/datasets/fineweb10B_sp4096/*.bin | wc -l) shards)." +fi + +# ---- 4. Download SP1024 data (backup/baseline) ---- +echo "=== Checking SP1024 dataset ===" +if [ ! 
-d "data/datasets/fineweb10B_sp1024" ] || [ "$(ls data/datasets/fineweb10B_sp1024/*.bin 2>/dev/null | wc -l)" -lt 80 ]; then + echo "Downloading SP1024..." + MATCHED_FINEWEB_REPO_ID=willdepueoai/parameter-golf \ + python3 data/cached_challenge_fineweb.py --variant sp1024 --skip-manifest || true +fi + +echo "" +echo "=== SETUP COMPLETE ===" +echo "Now run: bash scripts/h100-sp4096-ablations.sh" +echo "" diff --git a/scripts/h100-sp4096-ablations.sh b/scripts/h100-sp4096-ablations.sh new file mode 100644 index 0000000000..6032a25853 --- /dev/null +++ b/scripts/h100-sp4096-ablations.sh @@ -0,0 +1,120 @@ +#!/bin/bash +# ============================================================================= +# SP4096 ABLATION SUITE (1×H100) — STRICTLY SEQUENTIAL +# Runs one at a time. No parallel runs. No tmux windows per run. +# Updated: 2026-04-05 (fixed: sequential only) +# ============================================================================= + +set -e +cd /workspace/parameter-golf +mkdir -p logs + +TRAIN="train_gpt_full_stack.py" +[ -f "$TRAIN" ] || { echo "ERROR: $TRAIN not found!"; exit 1; } +echo "Using: $TRAIN" +echo "Start: $(date)" +echo "" + +# Common SP4096 base env vars +export VOCAB_SIZE=4096 +export DATA_PATH=./data/datasets/fineweb10B_sp4096 +export TOKENIZER_PATH=./data/tokenizers/fineweb_4096_bpe.model +export TRAIN_SEQ_LEN=4096 +export MLP_MULT=4 +export MUON_MOMENTUM=0.95 +export WARMDOWN_SCHEDULE=sqrt +export QK_GAIN_INIT=5.0 +export BIGRAMHASH_DIM=3072 +export MAX_WALLCLOCK_SECONDS=300 +export TTT_ENABLED=0 + +run_sequential() { + local name="$1" + local logfile="logs/sp4096_${name}.log" + + if [ -f "$logfile" ] && grep -qE "roundtrip|val_bpb" "$logfile" 2>/dev/null; then + bpb=$(grep -oP '[0-9]\.[0-9]{4,}' "$logfile" | tail -1) + echo "SKIP (done): $name BPB=$bpb" + return + fi + + echo "" + echo "============================================" + echo "STARTING: $name ($(date))" + echo "============================================" + + torchrun --nproc_per_node=1 --standalone "$TRAIN" 2>&1 | tee "$logfile" + + bpb=$(grep -oP '[0-9]\.[0-9]{4,}' "$logfile" | tail -1) + echo "" + echo "DONE: $name BPB=${bpb:-???} ($(date))" + echo "============================================" +} + +# ============================================================================= +# PHASE A: SP4096 base configs +# ============================================================================= +echo "=== PHASE A: SP4096 BASE CONFIGS ===" + +# A0: Pure baseline (no MuonEq-R) +MUONEQ_R=0 PARALLEL_START_LAYER=-1 RECUR_LAYERS="" \ + run_sequential "A0_sp4096_base" + +# A1: + MuonEq-R (biggest winner: -0.0627 on sp1024) +MUONEQ_R=1 PARALLEL_START_LAYER=-1 RECUR_LAYERS="" \ + run_sequential "A1_muoneqr" + +# A2: + depth recurrence 3-layer (layers 3,4,5) +MUONEQ_R=1 RECUR_LAYERS="3,4,5" PARALLEL_START_LAYER=-1 \ + run_sequential "A2_depthrecur3" + +# A3: + two-lane parallel residuals (start at layer 4) +MUONEQ_R=1 RECUR_LAYERS="" PARALLEL_START_LAYER=4 \ + run_sequential "A3_parallel" + +# A4: Full arch combo +MUONEQ_R=1 RECUR_LAYERS="3,4,5" PARALLEL_START_LAYER=4 \ + run_sequential "A4_full_arch" + +# ============================================================================= +# PHASE B: Pre-quant TTT +# ============================================================================= +echo "" +echo "=== PHASE B: PRE-QUANT TTT ===" + +MUONEQ_R=1 RECUR_LAYERS="3,4,5" PARALLEL_START_LAYER=4 \ + TTT_ENABLED=1 TTT_PREQUANT=1 TTT_OPTIMIZER=adamw TTT_LR=0.005 TTT_EPOCHS=6 \ + run_sequential "B0_prequant_ttt" + 
+MUONEQ_R=1 RECUR_LAYERS="3,4,5" PARALLEL_START_LAYER=4 \ + TTT_ENABLED=1 TTT_PREQUANT=1 TTT_DISCRIMINATIVE=1 TTT_EPOCHS=10 \ + run_sequential "B1_disc_ttt" + +# ============================================================================= +# PHASE C: Full GPTQ + Causal SLOT +# ============================================================================= +echo "" +echo "=== PHASE C: FULL GPTQ + CAUSAL SLOT ===" + +MUONEQ_R=1 RECUR_LAYERS="3,4,5" PARALLEL_START_LAYER=4 \ + TTT_ENABLED=1 TTT_PREQUANT=1 TTT_DISCRIMINATIVE=1 TTT_EPOCHS=10 \ + GPTQ_FULL_HESSIAN=1 GPTQ_DAMP=0.005 \ + run_sequential "C0_full_gptq" + +MUONEQ_R=1 RECUR_LAYERS="3,4,5" PARALLEL_START_LAYER=4 \ + TTT_ENABLED=1 TTT_PREQUANT=1 TTT_DISCRIMINATIVE=1 TTT_EPOCHS=10 \ + GPTQ_FULL_HESSIAN=1 GPTQ_DAMP=0.005 CAUSAL_SLOT_ENABLED=1 \ + run_sequential "C1_causal_slot" + +# ============================================================================= +# SUMMARY +# ============================================================================= +echo "" +echo "=== ALL DONE: $(date) ===" +echo "" +echo "Results (roundtrip BPB):" +for log in logs/sp4096_*.log; do + name=$(basename "$log" .log) + bpb=$(grep -oP '[0-9]\.[0-9]{4,}' "$log" 2>/dev/null | tail -1) + echo " $name: ${bpb:-???}" +done diff --git a/train_gpt.py b/train_gpt.py index 651beb2b89..585a16715f 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1,11 +1,4 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - from __future__ import annotations - import copy import glob import io @@ -18,7 +11,11 @@ import uuid import zlib from pathlib import Path - +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False import numpy as np import sentencepiece as spm import torch @@ -26,51 +23,34 @@ import torch.nn.functional as F from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") train_files = os.path.join(data_path, "fineweb_train_*.bin") val_files = os.path.join(data_path, "fineweb_val_*.bin") tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. 
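+    # VAL_MAX_TOKENS > 0 truncates the validation split for quicker local evals; 0 keeps the full split.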
+ val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) @@ -85,17 +65,28 @@ class Hyperparameters: beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. 
a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= X.norm() + eps @@ -107,26 +98,21 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) - B = b * A + c * A @ A X = a * X + B @ X return X.T if transposed else X - - class Muon(torch.optim.Optimizer): def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): super().__init__( params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), ) - @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() - distributed = dist.is_available() and dist.is_initialized() world_size = dist.get_world_size() if distributed else 1 rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: params = group["params"] if not params: @@ -135,10 +121,8 @@ def step(self, closure=None): momentum = group["momentum"] backend_steps = group["backend_steps"] nesterov = group["nesterov"] - total_params = sum(int(p.numel()) for p in params) updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 for i, p in enumerate(params): if i % world_size == rank and p.grad is not None: @@ -151,32 +135,26 @@ def step(self, closure=None): if nesterov: g = g.add(buf, alpha=momentum) g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. g *= max(1, g.size(0) / g.size(1)) ** 0.5 updates_flat[curr : curr + p.numel()] = g.reshape(-1) curr += p.numel() - if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - curr = 0 for p in params: g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) p.add_(g, alpha=-lr) curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() def build_sentencepiece_luts( sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device ) -> tuple[Tensor, Tensor, Tensor]: @@ -202,20 +180,18 @@ def build_sentencepiece_luts( torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), ) - - def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") return tokens[: usable + 1] - - def eval_val( args: Hyperparameters, model: nn.Module, @@ -228,9 +204,6 @@ def eval_val( has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, ) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) if local_batch_tokens < args.train_seq_len: raise ValueError( @@ -245,7 +218,6 @@ def eval_val( val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): @@ -265,31 +237,243 @@ def eval_val( token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count bits_per_token = val_loss.item() / math.log(2.0) tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
- +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + 
chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, 
seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", ).split(",") if pattern ) @@ -306,10 +490,8 @@ def eval_val( INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) - def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): return t.float().contiguous() @@ -317,34 +499,44 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t - +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + 
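+            # Per-row reconstruction MSE for this clipping percentile; each row keeps
+            # whichever percentile reconstructs it best (see the update below).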
recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. + return torch.empty_like(t32, dtype=torch.int8), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() return q, scale - def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes quantized: dict[str, Tensor] = {} scales: dict[str, Tensor] = {} dtypes: dict[str, str] = {} @@ -355,27 +547,21 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), 0, ) - for name, tensor in state_dict.items(): t = tensor.detach().to("cpu").contiguous() stats["param_count"] += int(t.numel()) stats["num_tensors"] += 1 stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): stats["num_nonfloat_tensors"] += 1 passthrough[name] = t stats["int8_payload_bytes"] += tensor_nbytes(t) continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: kept = keep_float_tensor(name, t, passthrough_orig_dtypes) passthrough[name] = kept stats["int8_payload_bytes"] += tensor_nbytes(kept) continue - stats["num_float_tensors"] += 1 q, s = quantize_float_tensor(t) if s.ndim > 0: @@ -384,9 +570,8 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): scales[name] = s dtypes[name] = str(t.dtype).removeprefix("torch.") stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", "quantized": quantized, "scales": scales, "dtypes": dtypes, @@ -397,7 +582,6 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): if passthrough_orig_dtypes: obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes return obj, stats - def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out: dict[str, Tensor] = {} qmeta = obj.get("qmeta", {}) @@ -407,30 +591,21 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: s = obj["scales"][name] if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() else: scale = float(s.item()) out[name] = (q.float() * scale).to(dtype=dtype).contiguous() for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. out_t = t.detach().to("cpu").contiguous() orig_dtype = passthrough_orig_dtypes.get(name) if isinstance(orig_dtype, str): out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() out[name] = out_t return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" Tensor: if tokens_np.size != num_tokens: raise ValueError(f"Short read for {file}") return torch.from_numpy(tokens_np.astype(np.uint16, copy=False)) - - class TokenStream: - # Reads shards sequentially and wraps around forever. The training loop therefore - # has deterministic, simple streaming behavior with no sampling or workers. def __init__(self, pattern: str): self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: @@ -453,12 +624,10 @@ def __init__(self, pattern: str): self.file_idx = 0 self.tokens = load_data_shard(self.files[0]) self.pos = 0 - def _advance_file(self) -> None: self.file_idx = (self.file_idx + 1) % len(self.files) self.tokens = load_data_shard(self.files[self.file_idx]) self.pos = 0 - def take(self, n: int) -> Tensor: chunks: list[Tensor] = [] remaining = n @@ -472,17 +641,12 @@ def take(self, n: int) -> Tensor: self.pos += k remaining -= k return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size self.device = device self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: local_tokens = global_tokens // (self.world_size * grad_accum_steps) per_rank_span = local_tokens + 1 @@ -492,37 +656,45 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> x = local[:-1].reshape(-1, seq_len) y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() self.eps = eps - def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + _int6_qat: bool = False def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - + return F.linear(x, w, bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: param.data = param.data.float() - - +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. 
def __init__(self, dim: int, base: float = 10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) @@ -530,7 +702,6 @@ def __init__(self, dim: int, base: float = 10000.0): self._seq_len_cached = 0 self._cos_cached: Tensor | None = None self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: if ( self._cos_cached is None @@ -544,14 +715,10 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup self._sin_cached = freqs.sin()[None, None, :, :] self._seq_len_cached = seq_len return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - class CausalSelfAttention(nn.Module): def __init__( self, @@ -560,6 +727,8 @@ def __init__( num_kv_heads: int, rope_base: float, qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, ): super().__init__() if dim % num_heads != 0: @@ -578,8 +747,10 @@ def __init__( self.proj = CastedLinear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) @@ -588,8 +759,16 @@ def forward(self, x: Tensor) -> Tensor: q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention( q, @@ -599,24 +778,55 @@ def forward(self, x: Tensor) -> Tensor: is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) - - class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup def __init__(self, dim: int, mlp_mult: int): super().__init__() hidden = mlp_mult * dim self.fc = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) self.proj._zero_init = True - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) + x = F.leaky_relu(self.fc(x), negative_slope=0.5) return self.proj(x.square()) - - +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = 
num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out class Block(nn.Module): def __init__( self, @@ -626,25 +836,41 @@ def __init__( mlp_mult: int, rope_base: float, qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor def forward(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) return x - - class GPT(nn.Module): def __init__( self, @@ -659,6 +885,16 @@ def __init__( logit_softcap: float, rope_base: float, qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, ): super().__init__() if logit_softcap <= 0.0: @@ -667,51 +903,97 @@ def __init__( self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, 
dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, ) - for i in range(num_layers) - ] - ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True self._init_weights() - def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) for module in self.modules(): if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) x = F.rms_norm(x, (x.size(-1),)) x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. 
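# --- Illustrative aside (not part of the patch): the two forward paths, sketched ---
# The forward pass below either (a) runs encoder blocks while pushing activations onto a
# stack and lets decoder blocks pop them back in with learned per-channel skip weights, or
# (b) when recurrence_repeats > 0, reapplies one shared block R times, blending each pass
# with its own residual through a per-repeat sigmoid gate (init 0.5). The stub blocks used
# here are stand-ins for the real modules (which also take x0).
import torch

def unet_forward(x, enc_blocks, dec_blocks, skip_weights):
    skips = []
    for blk in enc_blocks:
        x = blk(x)
        skips.append(x)
    for w, blk in zip(skip_weights, dec_blocks):
        if skips:
            x = x + w * skips.pop()                   # reuse encoder activations in reverse order
        x = blk(x)
    return x

def recurrent_forward(x, shared_block, gates):
    for r in range(gates.numel()):
        g = torch.sigmoid(gates[r])
        x = g * shared_block(x) + (1.0 - g) * x       # gate near 0 makes the extra pass a no-op
    return x

x = torch.randn(2, 8)
_ = unet_forward(x, [lambda h: h * 0.9] * 2, [lambda h: h * 0.9] * 2, torch.ones(2, 8))
_ = recurrent_forward(x, lambda h: h * 0.9, torch.full((3,), 0.5))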
- for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: @@ -721,24 +1003,48 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: raise RuntimeError("lm_head is required when tie_embeddings=False") logits_proj = self.lm_head(x) logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) def main() -> None: global zeropower_via_newtonschulz5 - code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -757,23 +1063,19 @@ def main() -> None: dist.init_process_group(backend="nccl", device_id=device) dist.barrier() master_process = rank == 0 - - # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True 
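# --- Illustrative aside (not part of the patch): effect of the logit soft-cap ---
# Both forward() and forward_logits() above squash the raw head projection through
# softcap * tanh(z / softcap). That bounds every logit to (-softcap, softcap) while staying
# nearly identity for small z, since tanh(z/c) ~ z/c when |z| << c. The value 30.0 is the
# LOGIT_SOFTCAP default used elsewhere in this patch.
import torch

softcap = 30.0
z = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = softcap * torch.tanh(z / softcap)            # ~[-29.9, -1.00, 0.00, 1.00, 29.9]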
torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - + _use_flash = sys.platform != "win32" enable_cudnn_sdp(False) - enable_flash_sdp(True) + enable_flash_sdp(_use_flash) enable_mem_efficient_sdp(False) - enable_math_sdp(False) - + enable_math_sdp(not _use_flash) logfile = None if master_process: os.makedirs("logs", exist_ok=True) logfile = f"logs/{args.run_id}.txt" print(logfile) - def log0(msg: str, console: bool = True) -> None: if not master_process: return @@ -782,7 +1084,6 @@ def log0(msg: str, console: bool = True) -> None: if logfile is not None: with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) - log0(code, console=False) log0("=" * 100, console=False) log0(f"Running Python {sys.version}", console=False) @@ -792,16 +1093,10 @@ def log0(msg: str, console: bool = True) -> None: console=False, ) log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) @@ -818,11 +1113,7 @@ def log0(msg: str, console: bool = True) -> None: log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - + CastedLinear._int6_qat = args.int6_qat base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, @@ -835,20 +1126,36 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if 
base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ p for name, p in block_named_params @@ -862,8 +1169,11 @@ def log0(msg: str, console: bool = True) -> None: if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, @@ -891,11 +1201,10 @@ def log0(msg: str, console: bool = True) -> None: fused=True, ) optimizers.insert(1, optimizer_head) - n_params = sum(p.numel() for p in base_model.parameters()) log0(f"model_params:{n_params}") log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") log0( f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " @@ -908,19 +1217,26 @@ def log0(msg: str, console: bool = True) -> None: f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" ) log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: for opt in optimizers: opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float) -> float: if args.warmdown_iters <= 0: return 1.0 @@ -931,9 +1247,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: warmdown_ms = args.warmdown_iters * step_ms remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. 
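# --- Illustrative aside (not part of the patch): how the wallclock warmdown behaves ---
# lr_mul above decays the LR multiplier over the *time* the last warmdown_iters steps are
# expected to take, based on an estimated per-step time, rather than over a fixed step
# schedule. The numbers below are assumed toy values, not from a real run.
step, elapsed_ms = 8_000, 400_000.0                      # ~50 ms/step so far
warmdown_iters, max_wallclock_ms = 3_500, 600_000.0
step_ms = elapsed_ms / max(step, 1)                      # 50.0 ms average step (assumed estimate)
warmdown_ms = warmdown_iters * step_ms                   # 175_000 ms of expected warmdown
remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)   # 200_000 ms left until the cap
scale = remaining_ms / warmdown_ms if remaining_ms <= warmdown_ms else 1.0   # still 1.0 here
# Once remaining_ms falls below warmdown_ms, scale ramps linearly toward 0 at the cap.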
if args.warmup_steps > 0: initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] @@ -959,24 +1272,28 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if distributed: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") training_time_ms = 0.0 stop_after_step: int | None = None torch.cuda.synchronize() t0 = time.perf_counter() - step = 0 while True: last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) if should_validate: torch.cuda.synchronize() training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) val_loss, val_bpb = eval_val( args, model, @@ -989,13 +1306,17 @@ def lr_mul(step: int, elapsed_ms: float) -> float: has_leading_space_lut, is_boundary_token_lut, ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup log0( f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" ) torch.cuda.synchronize() t0 = time.perf_counter() - if last_step: if stop_after_step is not None and step < args.iterations: log0( @@ -1003,7 +1324,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"step:{step}/{args.iterations}" ) break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) scale = lr_mul(step, elapsed_ms) zero_grad_all() @@ -1017,23 +1337,36 @@ def lr_mul(step: int, elapsed_ms: float) -> float: train_loss += loss.detach() (loss * grad_scale).backward() train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum for group in optimizer_muon.param_groups: group["momentum"] = muon_momentum - for opt in optimizers: for group in opt.param_groups: group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) for opt in optimizers: opt.step() zero_grad_all() - + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = Path("checkpoints") + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = 
sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) should_log_train = ( args.train_log_every > 0 @@ -1044,8 +1377,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) - - # Needed to sync whether we've reached the wallclock cap. reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms if distributed and max_wallclock_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) @@ -1053,18 +1384,15 @@ def lr_mul(step: int, elapsed_ms: float) -> float: reached_cap = bool(reached_cap_tensor.item()) if stop_after_step is None and reached_cap: stop_after_step = step - log0( f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") if master_process: torch.save(base_model.state_dict(), "final_model.pt") model_bytes = os.path.getsize("final_model.pt") @@ -1072,30 +1400,39 @@ def lr_mul(step: int, elapsed_ms: float) -> float: log0(f"Serialized model: {model_bytes} bytes") log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") - + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22) + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) quant_raw_bytes = len(quant_raw) + artifact_name = f"final_model.{quant_label}.ptz" if master_process: - with open("final_model.int8.ptz", "wb") as f: + with open(artifact_name, "wb") as f: f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") + quant_file_bytes = os.path.getsize(artifact_name) code_bytes = len(code.encode("utf-8")) ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") if distributed: dist.barrier() - with open("final_model.int8.ptz", "rb") as f: + with open(artifact_name, "rb") as f: quant_blob_disk = f.read() - quant_state = 
torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() @@ -1113,14 +1450,60 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) torch.cuda.synchronize() log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + torch.cuda.synchronize() + log0( + f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") if distributed: dist.destroy_process_group() - - if __name__ == "__main__": main() diff --git a/train_gpt_full_stack.py b/train_gpt_full_stack.py new file mode 100644 index 0000000000..7f9cccaa2e --- /dev/null +++ b/train_gpt_full_stack.py @@ -0,0 +1,2695 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as 
zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +try: + import brotli as _brotli + _HAS_BROTLI = True +except ImportError: + _HAS_BROTLI = False +import lzma +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + arr = bytearray(data) + n = len(arr) + out = bytearray(n) + for s in range(stride): + src_start = s + dst_start = s * (n // stride) + min(s, n % stride) + count = (n - s + stride - 1) // stride + for i in range(count): + out[dst_start + i] = arr[src_start + i * stride] + return bytes(out) + +def _byte_unshuffle(data: bytes, stride: int = 2) -> bytes: + arr = bytearray(data) + n = len(arr) + out = bytearray(n) + for s in range(stride): + src_start = s * (n // stride) + min(s, n % stride) + dst_start = s + count = (n - s + stride - 1) // stride + for i in range(count): + out[dst_start + i * stride] = arr[src_start + i] + return bytes(out) + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", f"./data/datasets/fineweb10B_sp{os.environ.get('VOCAB_SIZE', '4096')}") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", f"./data/tokenizers/fineweb_{os.environ.get('VOCAB_SIZE', '4096')}_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 4096)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = 
float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 3072)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.005)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 6)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + warmdown_schedule = os.environ.get("WARMDOWN_SCHEDULE", "sqrt") + weight_decay = float(os.environ.get("WEIGHT_DECAY", "0.04")) + bigramhash_dim = int(os.environ.get("BIGRAMHASH_DIM", 112)) + encoder_layers = int(os.environ.get("ENCODER_LAYERS", 1)) + recur_layers = os.environ.get("RECUR_LAYERS", "") + recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", -1)) + muoneq_r = bool(int(os.environ.get("MUONEQ_R", "1"))) + ttt_prequant = bool(int(os.environ.get("TTT_PREQUANT", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") + ttt_discriminative = bool(int(os.environ.get("TTT_DISCRIMINATIVE", "0"))) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.2)) + polar_express = bool(int(os.environ.get("POLAR_EXPRESS", "1"))) + gptq_full_hessian = bool(int(os.environ.get("GPTQ_FULL_HESSIAN", "1"))) + gptq_damp = float(os.environ.get("GPTQ_DAMP", 0.005)) + gptq_blocksize = int(os.environ.get("GPTQ_BLOCKSIZE", 128)) + gptq_actorder = bool(int(os.environ.get("GPTQ_ACTORDER", "1"))) + gptq_ar_selfgen = bool(int(os.environ.get("GPTQ_AR_SELFGEN", "1"))) + gptq_calib_samples = int(os.environ.get("GPTQ_CALIB_SAMPLES", 64)) + gptq_calib_seqlen = int(os.environ.get("GPTQ_CALIB_SEQLEN", 2048)) + gptq_calib_temp = float(os.environ.get("GPTQ_CALIB_TEMP", 0.8)) + causal_slot_enabled = bool(int(os.environ.get("CAUSAL_SLOT_ENABLED", "0"))) + causal_slot_steps = int(os.environ.get("CAUSAL_SLOT_STEPS", 8)) + causal_slot_lr = float(os.environ.get("CAUSAL_SLOT_LR", 0.005)) + causal_slot_dim = 
int(os.environ.get("CAUSAL_SLOT_DIM", 0)) + +_PE_COEFFS = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +] + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7, polar_express: bool = False) -> Tensor: + if polar_express: + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for i in range(min(steps, len(_PE_COEFFS))): + a, b, c = _PE_COEFFS[i] + A = X @ X.T + X = a * X + b * (A @ X) + c * (A @ (A @ X)) + if transposed: + X = X.T + return X + else: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if transposed: + X = X.T + return X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if group.get("muoneq_r", False): + g = g / g.norm(dim=-1, keepdim=True).clamp_min(1e-7) + g = zeropower_via_newtonschulz5(g, steps=backend_steps, polar_express=group.get("polar_express", False)) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + wd = group.get("weight_decay", 0.0) + if wd > 0: + p.add_(p, alpha=-wd * lr) + curr += p.numel() + return loss + +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 
15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + +class CausalSLOT: + """Causal SLOT: optimize a delta vector using only already-scored context positions.""" + def __init__(self, model_dim: int, device, lr: float = 0.005, steps: int = 8): + self.delta = torch.zeros(1, 1, model_dim, device=device, requires_grad=True) + self.optimizer = torch.optim.AdamW([self.delta], lr=lr) + self.steps = steps + + def reset(self): + self.delta.data.zero_() + self.optimizer = torch.optim.AdamW([self.delta], lr=self.optimizer.param_groups[0]['lr']) + + def optimize(self, model, input_ids: Tensor, context_mask: Tensor, target_ids: Tensor): + if not context_mask.any(): + return self.delta.detach() + for _ in range(self.steps): + self.optimizer.zero_grad() + with torch.no_grad(): + x = model.tok_emb(input_ids) + if hasattr(model, 'bigram_hash') and model.bigram_hash is not None: + x = x + model.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + for i in range(model.num_encoder_layers): + x = model.blocks[i](x, x0) + if model.num_decoder_layers > 0: + skips.append(x) + for i in range(model.num_decoder_layers): + if skips: + skip_idx = model.num_encoder_layers - 1 - i + if 0 <= skip_idx < len(skips): + x = x + model.skip_weights[min(i, len(model.skip_weights)-1)] * skips[skip_idx] + x = model.blocks[model.num_encoder_layers + i](x, x0) + x_with_delta = x + self.delta + x_normed = F.rms_norm(x_with_delta, (x_with_delta.size(-1),)) + if model.tie_embeddings: + logits = F.linear(x_normed, model.tok_emb.weight) + else: + logits = model.lm_head(x_normed) + if hasattr(model, 'logit_softcap') and 
model.logit_softcap > 0: + logits = model.logit_softcap * torch.tanh(logits / model.logit_softcap) + logits = logits.float() + context_logits = logits[0, context_mask] + context_targets = target_ids[0, context_mask] + loss = F.cross_entropy(context_logits, context_targets) + loss.backward() + self.optimizer.step() + return self.delta.detach() + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows 
= window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + # Initialize Causal SLOT if enabled + causal_slot = None + if args.causal_slot_enabled: + slot_dim = args.causal_slot_dim if args.causal_slot_dim > 0 else args.model_dim + causal_slot = CausalSLOT(slot_dim, device, lr=args.causal_slot_lr, steps=args.causal_slot_steps) + if causal_slot is not None: + # Causal SLOT path: process windows one at a time (delta per window) + t_slot_start = time.perf_counter() + # Freeze model weights for the duration + for p in base_model.parameters(): + p.requires_grad_(False) + for window_idx, ws in enumerate(my_windows): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + input_chunk = chunk[:-1].unsqueeze(0) # [1, wlen] + target_chunk = chunk[1:].unsqueeze(0) # [1, wlen] + # Pad to seq_len for uniform tensor shapes + if wlen < seq_len: + pad = torch.zeros(1, seq_len - wlen, dtype=torch.int64, device=device) + input_chunk = torch.cat([input_chunk, pad], dim=1) + target_chunk = torch.cat([target_chunk, pad], dim=1) + s = 0 if ws == 0 else max(wlen - stride, 0) + if causal_slot is not None and window_idx > 0 and s > 0: + # Context mask: positions 0..(s-1) are already-scored context + context_mask = torch.zeros(seq_len, dtype=torch.bool, device=device) + context_mask[:s] = True + # Optimize delta on context positions only + causal_slot.reset() + optimized_delta = causal_slot.optimize( + base_model, input_chunk, context_mask, target_chunk + ) + # Score the new stride positions with delta applied + with torch.no_grad(): + x = base_model.tok_emb(input_chunk) + if hasattr(base_model, 'bigram_hash') and base_model.bigram_hash is not None: + x = x + base_model.bigram_hash(input_chunk) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips_list: list[Tensor] = [] + for i in range(base_model.num_encoder_layers): + x = base_model.blocks[i](x, x0) + if base_model.num_decoder_layers > 0: + skips_list.append(x) + for i in range(base_model.num_decoder_layers): + if skips_list: + skip_idx = base_model.num_encoder_layers - 1 - i + if 0 <= skip_idx < len(skips_list): + x = x + base_model.skip_weights[min(i, len(base_model.skip_weights)-1)] * skips_list[skip_idx] + x = base_model.blocks[base_model.num_encoder_layers + i](x, x0) + x = x + optimized_delta + x_normed = base_model.final_norm(x) + if base_model.tie_embeddings: + logits_proj = F.linear(x_normed, base_model.tok_emb.weight) + else: + logits_proj = base_model.lm_head(x_normed) + logits_with_slot = base_model.logit_softcap * torch.tanh( + logits_proj / base_model.logit_softcap + ) + # Score only stride positions + score_logits = logits_with_slot[0, s:wlen].float() + score_targets = target_chunk[0, s:wlen] + else: + # Standard scoring without SLOT (first window or no context yet) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits_plain = base_model.forward_logits(input_chunk) + score_logits = logits_plain[0, s:wlen].float() + score_targets = target_chunk[0, s:wlen] + scored_nll = F.cross_entropy(score_logits, score_targets, reduction="none").to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = target_chunk[0, s:wlen] + prev = input_chunk[0, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += 
(has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (window_idx % 200 == 0 or window_idx == len(my_windows) - 1): + elapsed = time.perf_counter() - t_slot_start + print( + f" causal_slot window [{window_idx+1}/{len(my_windows)}] " + f"elapsed={elapsed:.1f}s" + ) + # Restore requires_grad + for p in base_model.parameters(): + p.requires_grad_(True) + else: + # Standard batch path (unchanged from original) + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: 
Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not 
int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = 
time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + if args.ttt_discriminative: + param_groups = [] + blocks = [b for b in base_model.blocks] if hasattr(base_model, 'blocks') else [] + num_blocks = len(blocks) + for i, block in enumerate(blocks): + if i < args.ttt_freeze_blocks: + continue + lr_scale = 0.3 + 0.7 * (i / max(num_blocks - 1, 1)) + block_params = [p for p in block.parameters() if p.requires_grad] + if block_params: + param_groups.append({"params": block_params, "lr": args.ttt_lr * lr_scale}) + block_param_ids = set() + for block in blocks: + block_param_ids.update(id(p) for p in block.parameters()) + other_params = [p for p in base_model.parameters() if p.requires_grad and id(p) not in block_param_ids] + if other_params: + param_groups.append({"params": other_params, "lr": args.ttt_lr}) + ttt_opt = torch.optim.SGD(param_groups, momentum=args.ttt_momentum) + elif args.ttt_optimizer == "adamw": + ttt_opt = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + ttt_opt = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + optimizer = ttt_opt + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + effective_ttt_epochs = 10 if args.ttt_discriminative else args.ttt_epochs + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and effective_ttt_epochs > 0: + base_model.train() + 
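+ # Illustrative aside (standalone sketch with assumed example values, not part of the run logic):
+ # each window was assigned above to the chunk containing its first scored token, and a chunk is
+ # always scored before any training happens on it, so the evaluation stays causal. The per-chunk
+ # learning rate used just below follows a cosine decay; with an assumed base lr of 0.002 and 5
+ # chunks it gives roughly [0.0020, 0.0017, 0.0010, 0.0003, 0.0] for ci = 0..4.
+ def _chunk_cosine_lr(base_lr: float, ci: int, num_chunks: int) -> float:
+     # same formula as the cos_lr assignment below, isolated for sanity-checking
+     return base_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1)))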
chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(effective_ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +# ============================================================================ +# Quantization constants +# ============================================================================ +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # int3: 7 signed levels in [-3, 3] + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +# ============================================================================ +# GPTQ class (full Hessian) +# ============================================================================ +class GPTQ: + """Full Hessian GPTQ quantizer for a single linear layer.""" + def __init__(self, weight: Tensor): + self.rows, self.cols = weight.shape + self.H = torch.zeros((self.cols, self.cols), device=weight.device, dtype=torch.float32) + self.nsamples = 0 + + def add_batch(self, inp: Tensor): + inp = inp.reshape(-1, inp.shape[-1]).float() + n = inp.shape[0] + self.H *= self.nsamples / (self.nsamples + n) + self.nsamples += n + self.H += (2.0 / (self.nsamples)) * (inp.T @ inp) + + def quantize(self, weight: Tensor, qmax: int = 31, blocksize: int = 128, + percdamp: float = 0.005, actorder: bool = True): + W = weight.clone().float() + H = self.H.clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + damp = percdamp * torch.mean(torch.diag(H)) + diag_idx = torch.arange(self.cols, device=H.device) + H[diag_idx, diag_idx] += damp + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + invperm = torch.argsort(perm) + try: + H_inv = torch.linalg.cholesky(H) + H_inv = torch.cholesky_inverse(H_inv) + H_inv = torch.linalg.cholesky(H_inv, upper=True) + except torch.linalg.LinAlgError: + H[diag_idx, diag_idx] += damp * 10 + H_inv = torch.linalg.cholesky(H) + H_inv = torch.cholesky_inverse(H_inv) + H_inv = torch.linalg.cholesky(H_inv, upper=True) + Q = torch.zeros_like(W) + for i1 in range(0, self.cols, blocksize): + i2 = min(i1 + blocksize, self.cols) + count = i2 - i1 + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Hinv1 = H_inv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + scale = w.abs().max().clamp(min=1e-8) / qmax + q = (w / scale).round().clamp(-qmax, qmax) * scale + Q1[:, i] = q + err = (w - q) / d + W1[:, i + 1:] -= err.unsqueeze(1) * Hinv1[i, i + 1:].unsqueeze(0) + Err1[:, i] = err + Q[:, i1:i2] = Q1 + W[:, i2:] -= Err1 @ H_inv[i1:i2, i2:] + if actorder: + Q =
Q[:, invperm] + return Q + +@torch.no_grad() +def generate_ar_calibration(model: nn.Module, num_samples: int = 64, seq_len: int = 2048, + temp: float = 0.8, vocab_size: int = 1024) -> Tensor: + device = next(model.parameters()).device + all_tokens = [] + for _ in range(num_samples): + tokens = torch.randint(0, vocab_size, (1, 1), device=device) + for _ in range(seq_len - 1): + logits = model.forward_logits(tokens) + next_logits = logits[:, -1, :] / temp + probs = torch.softmax(next_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + tokens = torch.cat([tokens, next_token], dim=1) + all_tokens.append(tokens.squeeze(0)) + return torch.stack(all_tokens) + +def collect_hessians(model: nn.Module, calibration_data: Tensor, gptq_modules: dict) -> None: + hooks = [] + def make_hook(name: str, gptq_obj: GPTQ): + def hook_fn(module: nn.Module, input: tuple, output: Tensor) -> None: + gptq_obj.add_batch(input[0]) + return hook_fn + for name, module in model.named_modules(): + if isinstance(module, CastedLinear) and name in gptq_modules: + h = module.register_forward_hook(make_hook(name, gptq_modules[name])) + hooks.append(h) + for i in range(calibration_data.shape[0]): + tokens = calibration_data[i:i+1] + model(tokens[:, :-1], tokens[:, 1:]) + for h in hooks: + h.remove() +# ============================================================================ +# GPTQ-lite quantization +# ============================================================================ +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. 
When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def 
dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +# ============================================================================ +# Data loading +# ============================================================================ +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype("<i4").itemsize # skip the fixed-size shard header, then read the uint16 token payload + with file.open("rb") as f: + f.seek(header_bytes) + tokens = np.frombuffer(f.read(), dtype="<u2") + return torch.from_numpy(tokens.astype(np.int32)) +class TokenStream: + def __init__(self, pattern: str): + path = Path(pattern) + self.files = sorted(path.parent.glob(path.name)) + if not self.files: + raise FileNotFoundError(f"no token shards match {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +# ============================================================================ +# Model components +# ============================================================================ +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int, proj_dim: int = 0): + super().__init__() + self.num_buckets = num_buckets + embed_dim = proj_dim
if proj_dim > 0 and proj_dim < dim else dim + self.emb = nn.Embedding(num_buckets, embed_dim) + self.proj = nn.Linear(embed_dim, dim, bias=False) if embed_dim < dim else None + self.scale = nn.Parameter(torch.tensor(0.05)) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + h = self.emb(bigram_hash) + if self.proj is not None: + h = self.proj(h) + return h * self.scale +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +# ============================================================================ +# Value Embedding +# ============================================================================ +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.emb = nn.Embedding(vocab_size, ve_dim) + self.proj = nn.Linear(ve_dim, kv_dim, bias=False) + nn.init.zeros_(self.emb.weight) + nn.init.zeros_(self.proj.weight) + def forward(self, token_ids: Tensor) -> Tensor: + return self.proj(self.emb(token_ids)) +# ============================================================================ +# RepeatMLP (depth recurrence) +# ============================================================================ +class RepeatMLP(nn.Module): + def __init__(self, dim: int, hidden: int): + super().__init__() + self.fc = nn.Linear(dim, hidden, bias=False) + self.proj = nn.Linear(hidden, dim, bias=False) + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), 0.5).square()) +# ============================================================================ +# Attention +# ============================================================================ +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ve: ValueEmbedding = None, + ): + 
super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + self.ve = ve + def forward(self, x: Tensor, input_ids: Tensor = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.ve is not None and input_ids is not None: + ve_out = self.ve(input_ids) + ve_out = ve_out.view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v + ve_out + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +# ============================================================================ +# MLP / MoE +# ============================================================================ +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = int(mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init 
= True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +# ============================================================================ +# Block +# ============================================================================ +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + has_repeat_mlp: bool = False, + ve: ValueEmbedding = None, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ve=ve, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + if has_repeat_mlp: + mlp_hidden = int(dim * mlp_mult) + self.repeat_mlp = RepeatMLP(dim, mlp_hidden) + else: + self.repeat_mlp = None + def forward(self, x: Tensor, x0: Tensor, input_ids: Tensor = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed, input_ids=input_ids) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x + def forward_repeat(self, x: Tensor, x0: Tensor) -> Tensor: + """Same as forward but using repeat_mlp instead of mlp for the MLP step.""" + mix = self.resid_mix.to(dtype=x.dtype) + x_ln = self.attn_norm(mix[0] * x + mix[1] * x0) + attn_out = self.attn(x_ln) + x = x + self.attn_scale.to(dtype=x.dtype) * attn_out + mlp_in = self.mlp_norm(x) + mlp_out = self.repeat_mlp(mlp_in) + x = x + self.mlp_scale.to(dtype=x.dtype) * mlp_out + return x +# ============================================================================ +# GPT Model +# ============================================================================ +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + bigramhash_dim: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + 
rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + encoder_layers: int = 0, + recur_layer_indices: list[int] | None = None, + recur_start_step: int = 3000, + parallel_start_layer: int = -1, + ve_layer_indices: list = None, + ve_dim: int = 128, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim, proj_dim=bigramhash_dim) if bigramhash_buckets > 0 else None + # Value embedding + self.ve_layer_indices = ve_layer_indices if ve_layer_indices is not None else [] + if self.ve_layer_indices: + kv_dim = num_kv_heads * (model_dim // num_heads) + self.value_embedding = ValueEmbedding(vocab_size, ve_dim, kv_dim) + else: + self.value_embedding = None + # Depth recurrence (untied MLP weights) + self.recur_layer_indices = recur_layer_indices if recur_layer_indices is not None else [] + self.recur_start_step = recur_start_step + self._recur_active = False + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + if encoder_layers > 0: + self.num_encoder_layers = encoder_layers + elif self.recur_layer_indices: + virtual_num_layers = num_layers + len(self.recur_layer_indices) + self.num_encoder_layers = min(virtual_num_layers // 2, num_layers) + else: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 
1.0, + moe_num_experts=moe_num_experts, + has_repeat_mlp=(i in self.recur_layer_indices), + ve=self.value_embedding if i in self.ve_layer_indices else None, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + # Two-lane parallel parameters + self.unet_skips = unet_skips + self.parallel_start_layer = parallel_start_layer + if self.parallel_start_layer >= 0: + num_parallel = num_layers - self.parallel_start_layer + self.parallel_post_lambdas = nn.Parameter(torch.ones(num_parallel, 2, 2, dtype=torch.float32)) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((num_parallel, 2), math.sqrt(1.1), dtype=torch.float32) + ) + else: + self.parallel_post_lambdas = None + self.parallel_resid_lambdas = None + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def _sync_repeat_mlp_from_base(self) -> None: + """Copy MLP weights to repeat_mlp weights at recurrence activation step.""" + for li in self.recur_layer_indices: + block = self.blocks[li] + if block.repeat_mlp is not None: + block.repeat_mlp.fc.weight.data.copy_(block.mlp.fc.weight.data) + block.repeat_mlp.proj.weight.data.copy_(block.mlp.proj.weight.data) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + lane0: Tensor | None = None + lane1: Tensor | None = None + + # Encoder pass + for i in range(self.num_encoder_layers): + layer_idx = i + if self.parallel_start_layer >= 0 and layer_idx >= self.parallel_start_layer: + # Two-lane mode + if lane0 is None: + lane0 = x.clone() + lane1 = x.clone() + + pi = layer_idx - self.parallel_start_layer + block = self.blocks[layer_idx] + + mix = block.resid_mix.to(dtype=lane0.dtype) + + # Attn reads from lane0 + attn_in = block.attn_norm(mix[0] * lane0 + mix[1] * x0) + attn_out = block.attn(attn_in, input_ids=input_ids if layer_idx in self.ve_layer_indices else None) * block.attn_scale.to(dtype=lane0.dtype) + + # MLP reads from lane1 + mlp_in = block.mlp_norm(mix[0] * lane1 + mix[1] * x0) + mlp_out = block.mlp(mlp_in) * block.mlp_scale.to(dtype=lane1.dtype) + + # Cross-lane routing + post = self.parallel_post_lambdas[pi].to(dtype=lane0.dtype) # [2, 2] + resid_l = self.parallel_resid_lambdas[pi].to(dtype=lane0.dtype) # [2] + + # Update lanes + new_lane0 = resid_l[0] * lane0 + post[0, 0] * attn_out + post[1, 0] * mlp_out + new_lane1 = resid_l[1] * lane1 + post[0, 1] * attn_out + post[1, 1] * mlp_out + lane0, lane1 = new_lane0, new_lane1 + else: + # Standard sequential mode + x = self.blocks[layer_idx](x, x0, 
input_ids=input_ids if layer_idx in self.ve_layer_indices else None) + + if self.unet_skips: + skips.append(lane0 if lane0 is not None else x) + + # Depth recurrence pass (if active) + if self._recur_active and self.recur_layer_indices: + if lane0 is not None: + x = (lane0 + lane1) * 0.5 + lane0 = None + lane1 = None + for li in self.recur_layer_indices: + x = self.blocks[li].forward_repeat(x, x0) + if self.parallel_start_layer >= 0: + lane0 = x.clone() + lane1 = x.clone() + + # Decoder pass + for i in range(self.num_decoder_layers): + layer_idx = self.num_encoder_layers + i + # Apply skip connections + if skips: + cur_dtype = lane0.dtype if lane0 is not None else x.dtype + skip_idx = self.num_encoder_layers - 1 - i + if 0 <= skip_idx < len(skips): + skip_val = self.skip_weights[min(i, len(self.skip_weights) - 1)].to(dtype=cur_dtype)[None, None, :] * skips[skip_idx] + if lane0 is not None: + lane0 = lane0 + skip_val + lane1 = lane1 + skip_val + else: + x = x + skip_val + + if self.parallel_start_layer >= 0 and layer_idx >= self.parallel_start_layer: + if lane0 is None: + lane0 = x.clone() + lane1 = x.clone() + + pi = layer_idx - self.parallel_start_layer + block = self.blocks[layer_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + + attn_in = block.attn_norm(mix[0] * lane0 + mix[1] * x0) + attn_out = block.attn(attn_in, input_ids=input_ids if layer_idx in self.ve_layer_indices else None) * block.attn_scale.to(dtype=lane0.dtype) + + mlp_in = block.mlp_norm(mix[0] * lane1 + mix[1] * x0) + mlp_out = block.mlp(mlp_in) * block.mlp_scale.to(dtype=lane1.dtype) + + post = self.parallel_post_lambdas[pi].to(dtype=lane0.dtype) + resid_l = self.parallel_resid_lambdas[pi].to(dtype=lane0.dtype) + + new_lane0 = resid_l[0] * lane0 + post[0, 0] * attn_out + post[1, 0] * mlp_out + new_lane1 = resid_l[1] * lane1 + post[0, 1] * attn_out + post[1, 1] * mlp_out + lane0, lane1 = new_lane0, new_lane1 + else: + x = self.blocks[layer_idx](x, x0, input_ids=input_ids if layer_idx in self.ve_layer_indices else None) + + # Final merge + if lane0 is not None: + x = (lane0 + lane1) * 0.5 + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + lane0: Tensor | None = None + lane1: Tensor | None = None + + # Encoder pass + for i in range(self.num_encoder_layers): + layer_idx = i + if self.parallel_start_layer >= 0 and layer_idx >= self.parallel_start_layer: + if lane0 is None: + lane0 = x.clone() + lane1 = x.clone() + + pi = layer_idx - 
self.parallel_start_layer + block = self.blocks[layer_idx] + + mix = block.resid_mix.to(dtype=lane0.dtype) + + attn_in = block.attn_norm(mix[0] * lane0 + mix[1] * x0) + attn_out = block.attn(attn_in, input_ids=input_ids if layer_idx in self.ve_layer_indices else None) * block.attn_scale.to(dtype=lane0.dtype) + + mlp_in = block.mlp_norm(mix[0] * lane1 + mix[1] * x0) + mlp_out = block.mlp(mlp_in) * block.mlp_scale.to(dtype=lane1.dtype) + + post = self.parallel_post_lambdas[pi].to(dtype=lane0.dtype) + resid_l = self.parallel_resid_lambdas[pi].to(dtype=lane0.dtype) + + new_lane0 = resid_l[0] * lane0 + post[0, 0] * attn_out + post[1, 0] * mlp_out + new_lane1 = resid_l[1] * lane1 + post[0, 1] * attn_out + post[1, 1] * mlp_out + lane0, lane1 = new_lane0, new_lane1 + else: + x = self.blocks[layer_idx](x, x0, input_ids=input_ids if layer_idx in self.ve_layer_indices else None) + + if self.unet_skips: + skips.append(lane0 if lane0 is not None else x) + + # Depth recurrence pass (if active) + if self._recur_active and self.recur_layer_indices: + if lane0 is not None: + x = (lane0 + lane1) * 0.5 + lane0 = None + lane1 = None + for li in self.recur_layer_indices: + x = self.blocks[li].forward_repeat(x, x0) + if self.parallel_start_layer >= 0: + lane0 = x.clone() + lane1 = x.clone() + + # Decoder pass + for i in range(self.num_decoder_layers): + layer_idx = self.num_encoder_layers + i + if skips: + cur_dtype = lane0.dtype if lane0 is not None else x.dtype + skip_idx = self.num_encoder_layers - 1 - i + if 0 <= skip_idx < len(skips): + skip_val = self.skip_weights[min(i, len(self.skip_weights) - 1)].to(dtype=cur_dtype)[None, None, :] * skips[skip_idx] + if lane0 is not None: + lane0 = lane0 + skip_val + lane1 = lane1 + skip_val + else: + x = x + skip_val + + if self.parallel_start_layer >= 0 and layer_idx >= self.parallel_start_layer: + if lane0 is None: + lane0 = x.clone() + lane1 = x.clone() + + pi = layer_idx - self.parallel_start_layer + block = self.blocks[layer_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + + attn_in = block.attn_norm(mix[0] * lane0 + mix[1] * x0) + attn_out = block.attn(attn_in, input_ids=input_ids if layer_idx in self.ve_layer_indices else None) * block.attn_scale.to(dtype=lane0.dtype) + + mlp_in = block.mlp_norm(mix[0] * lane1 + mix[1] * x0) + mlp_out = block.mlp(mlp_in) * block.mlp_scale.to(dtype=lane1.dtype) + + post = self.parallel_post_lambdas[pi].to(dtype=lane0.dtype) + resid_l = self.parallel_resid_lambdas[pi].to(dtype=lane0.dtype) + + new_lane0 = resid_l[0] * lane0 + post[0, 0] * attn_out + post[1, 0] * mlp_out + new_lane1 = resid_l[1] * lane1 + post[0, 1] * attn_out + post[1, 1] * mlp_out + lane0, lane1 = new_lane0, new_lane1 + else: + x = self.blocks[layer_idx](x, x0, input_ids=input_ids if layer_idx in self.ve_layer_indices else None) + + # Final merge + if lane0 is not None: + x = (lane0 + lane1) * 0.5 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +# ============================================================================ +# Main training function +# ============================================================================ +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not 
int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} 
tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + # Parse recur_layer_indices from args + recur_layer_indices = [int(x) for x in args.recur_layers.split(",") if x.strip()] if args.recur_layers else [] + # Parse ve_layer_indices from args + ve_layer_indices = [int(x) for x in args.ve_layers.split(",") if x.strip()] if args.ve_enabled else [] + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + bigramhash_dim=args.bigramhash_dim, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + encoder_layers=args.encoder_layers, + recur_layer_indices=recur_layer_indices, + recur_start_step=args.recur_start_step, + parallel_start_layer=args.parallel_start_layer, + ve_layer_indices=ve_layer_indices, + ve_dim=args.ve_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # Keep parallel routing params in float32 + with torch.no_grad(): + if base_model.parallel_post_lambdas is not None: + base_model.parallel_post_lambdas.data = base_model.parallel_post_lambdas.data.float() + if base_model.parallel_resid_lambdas is not None: + base_model.parallel_resid_lambdas.data = base_model.parallel_resid_lambdas.data.float() + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + if base_model.bigram_hash.proj is not None: + 
embed_params.append(base_model.bigram_hash.proj.weight) + embed_params.append(base_model.bigram_hash.scale) + if base_model.value_embedding is not None: + embed_params.append(base_model.value_embedding.emb.weight) + embed_params.append(base_model.value_embedding.proj.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + group["polar_express"] = args.polar_express + group["muoneq_r"] = args.muoneq_r + group["weight_decay"] = args.weight_decay + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + # Add repeat_mlp weights to the Muon optimizer param group + if base_model.recur_layer_indices: + repeat_mlp_matrix_params = [] + for li in base_model.recur_layer_indices: + block = base_model.blocks[li] + if block.repeat_mlp is not None: + repeat_mlp_matrix_params.append(block.repeat_mlp.fc.weight) + repeat_mlp_matrix_params.append(block.repeat_mlp.proj.weight) + if repeat_mlp_matrix_params: + optimizer_muon.add_param_group({ + "params": repeat_mlp_matrix_params, + "lr": args.matrix_lr, + "base_lr": args.matrix_lr, + "momentum": args.muon_momentum, + "backend_steps": args.muon_backend_steps, + "nesterov": True, + }) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'brotli' if _HAS_BROTLI else 'zstd-22' if _HAS_ZSTD else 'lzma-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + 
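# Optimizer-split recap (illustrative comment): 2D block matrices, apart from control tensors
# matched by CONTROL_TENSOR_NAME_PATTERNS, are updated by Muon, while the token / bigram-hash /
# value embeddings, the untied LM head, and all sub-2D or control parameters use fused Adam.
# Every param group also carries a "base_lr" entry so the training loop below can rescale
# learning rates in place each step:
#     group["lr"] = group["base_lr"] * lr_mul(step, elapsed_ms)
# e.g. a base_lr of 0.04 with a warmdown multiplier of 0.5 steps at an effective 0.02 while
# base_lr itself stays fixed for the next rescale.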
if base_model.recur_layer_indices: + log0( + f"depth_recurrence_v2: recur_layers={base_model.recur_layer_indices} " + f"recur_start_step={base_model.recur_start_step} " + f"(inactive until step {base_model.recur_start_step})" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + if not (warmdown_start <= step < args.iterations): + return 1.0 + t_frac = (step - warmdown_start) / max(args.warmdown_iters, 1) + else: + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + if remaining_ms > warmdown_ms: + return 1.0 + t_frac = 1.0 - remaining_ms / max(warmdown_ms, 1e-9) + if args.warmdown_schedule == "sqrt": + return max(1.0 - (t_frac ** 0.5), 0) + else: + return max(1.0 - t_frac, 0) + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + # SWA initialization + if args.swa_enabled: + swa_state = {} + swa_count = 0 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + 
is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].lerp_(p.data, 1.0 - decay) + # SWA accumulation + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_count == 0: + swa_state = {name: p.data.clone().float() for name, p in base_model.named_parameters()} + else: + for name, p in base_model.named_parameters(): + swa_state[name].mul_(swa_count).add_(p.data.float()).div_(swa_count + 1) + swa_count += 1 + if master_process: + log0(f"SWA checkpoint {swa_count} at step {step} (lr_scale={scale:.4f})") + step += 1 + # Activate depth recurrence at the designated step + if base_model.recur_layer_indices: + if step == base_model.recur_start_step: + base_model._sync_repeat_mlp_from_base() + base_model._recur_active = True + if master_process: + log0(f"Depth recurrence activated at step {step}, layers {base_model.recur_layer_indices}") + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) 
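# Weight-averaging recap (illustrative comment): the EMA update above,
#     ema_state[name].lerp_(p.data, 1.0 - decay)   i.e.   ema <- decay * ema + (1 - decay) * p,
# means roughly the last 1 / (1 - decay) steps dominate the average (about 333 for decay 0.997).
# The SWA branch keeps a plain running mean, swa_{k+1} = (k * swa_k + p) / (k + 1), over
# snapshots taken every swa_every steps once the LR multiplier drops below swa_start_frac.
# At export time below, EMA and SWA are blended 50/50 when SWA collected any snapshots,
# otherwise the EMA weights are exported alone.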
+ if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # ======================================================================== + # EMA / SWA export + # ======================================================================== + if args.swa_enabled and swa_count > 0: + if master_process: + log0(f"Merging EMA + SWA ({swa_count} checkpoints)") + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_val = ema_state[name] + swa_val = swa_state[name].to(dtype=ema_val.dtype) + p.data.copy_((ema_val + swa_val) * 0.5) + else: + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes -> {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + # ======================================================================== + # Pre-quantization TTT + # ======================================================================== + if args.ttt_enabled and args.ttt_prequant: + log0("prequant_ttt:starting TTT on FP32 EMA weights before quantization") + torch.cuda.synchronize() + t_prequant_ttt = time.perf_counter() + if args.ttt_lora_enabled: + prequant_ttt_loss, prequant_ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + prequant_ttt_label = "prequant_lora_ttt" + else: + prequant_ttt_loss, prequant_ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + prequant_ttt_label = "prequant_ttt" + torch.cuda.synchronize() + log0( + f"{prequant_ttt_label} val_loss:{prequant_ttt_loss:.4f} val_bpb:{prequant_ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_prequant_ttt):.0f}ms " + f"optimizer:{args.ttt_optimizer}" + ) + log0(f"{prequant_ttt_label}_exact val_loss:{prequant_ttt_loss:.8f} val_bpb:{prequant_ttt_bpb:.8f}") + # ======================================================================== + # Full Hessian GPTQ (before standard quantization) + # ======================================================================== + if args.gptq_full_hessian: + if master_process: + log0("Running Full Hessian GPTQ...") + calib_data: Tensor + if args.gptq_ar_selfgen: + if master_process: + log0(f"Generating 
{args.gptq_calib_samples} AR calibration sequences...") + base_model.eval() + calib_data = generate_ar_calibration( + base_model, args.gptq_calib_samples, args.gptq_calib_seqlen, + args.gptq_calib_temp, args.vocab_size + ) + base_model.train() + else: + raise ValueError("gptq_ar_selfgen=False requires an external calibration source (not implemented)") + gptq_modules: dict[str, GPTQ] = {} + for name, module in base_model.named_modules(): + if isinstance(module, CastedLinear): + gptq_modules[name] = GPTQ(module.weight.data) + base_model.eval() + with torch.no_grad(): + collect_hessians(base_model, calib_data, gptq_modules) + base_model.train() + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, CastedLinear) and name in gptq_modules: + gptq_obj = gptq_modules[name] + Q = gptq_obj.quantize( + module.weight.data, + qmax=_QUANT_MAX_VAL, + blocksize=args.gptq_blocksize, + percdamp=args.gptq_damp, + actorder=args.gptq_actorder + ) + module.weight.data.copy_(Q) + if master_process: + log0("Full Hessian GPTQ complete") + # ======================================================================== + # Quantization + # ======================================================================== + quant_label = f"int{_QUANT_BITS}" + compress_label = "brotli" if _HAS_BROTLI else "zstd" if _HAS_ZSTD else "lzma" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_BROTLI: + shuffled = _byte_shuffle(quant_raw) + quant_blob = _brotli.compress(shuffled, quality=11) + elif _HAS_ZSTD: + quant_blob = zstd.ZstdCompressor(level=22, threads=-1).compress(quant_raw) + else: + quant_blob = lzma.compress(quant_raw, preset=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + # ======================================================================== + # Decompression + roundtrip eval + # ======================================================================== + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_BROTLI: + quant_raw = _byte_unshuffle(_brotli.decompress(quant_blob_disk)) + elif _HAS_ZSTD: + quant_raw = zstd.ZstdDecompressor().decompress(quant_blob_disk) + else: + quant_raw = lzma.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_raw), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = 
time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + # ======================================================================== + # Post-quantization TTT (only if not already done pre-quant) + # ======================================================================== + if args.ttt_enabled and not args.ttt_prequant: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py deleted file mode 100644 index 7b9e935aa6..0000000000 --- a/train_gpt_mlx.py +++ /dev/null @@ -1,1104 +0,0 @@ -#!/usr/bin/env python3 -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good 
launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" -from __future__ import annotations - -import glob -import json -import math -import os -import pickle -import sys -import time -import uuid -import zlib -from collections.abc import Callable -from pathlib import Path - -import numpy as np -import sentencepiece as spm - -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as optim -from mlx.utils import tree_flatten, tree_unflatten - -# ============================================================================== -# SHARD FORMAT + COMPUTE DTYPE -# ============================================================================== - -COMPUTE_DTYPE = mx.bfloat16 - -# ============================================================================== -# HYPERPARAMETERS -# ============================================================================== -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap -class Hyperparameters: - # Data / tokenizer. - data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed: int = int(os.environ.get("SEED", 1337)) - - # Training loop. These defaults now mirror train_gpt.py on a single process. - iterations: int = int(os.environ.get("ITERATIONS", 20_000)) - val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) - # Validation always uses the full fineweb_val split. - val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) - train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) - # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak - # memory pressure without changing the effective optimizer batch. - mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) - # Force MLX to materialize the graph after every sub-batch, preventing lazy - # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. - # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). - mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) - warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) - warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) - max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - - # Model (defaults match the current baseline setup). 
- vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) - model_dim: int = int(os.environ.get("MODEL_DIM", 512)) - num_heads: int = int(os.environ.get("NUM_HEADS", 8)) - num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) - mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) - logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) - qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Optimizer. We keep the same per-group defaults as train_gpt.py. - beta1: float = float(os.environ.get("BETA1", 0.9)) - beta2: float = float(os.environ.get("BETA2", 0.95)) - adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) - tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) - matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - out_dir: str = os.environ.get("OUT_DIR", "logs") - - @property - def train_files(self) -> str: - return f"{self.data_path}/fineweb_train_*.bin" - - @property - def val_files(self) -> str: - return f"{self.data_path}/fineweb_val_*.bin" - - @property - def microbatch_tokens(self) -> int: - return self.train_batch_tokens // self.grad_accum_steps - - def lr_mul(self, step: int, elapsed_ms: float) -> float: - if self.warmdown_iters <= 0: - return 1.0 - if self.max_wallclock_seconds <= 0: - warmdown_start = max(self.iterations - self.warmdown_iters, 0) - return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = self.warmdown_iters * step_ms - remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) - - -def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: - usable_total = (total_tokens // seq_len) * seq_len - if usable_total <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) - chunks: list[int] = [] - remaining = usable_total - while remaining > 0: - chunk = min(remaining, usable_chunk) - chunks.append(chunk) - remaining -= chunk - return chunks - - -def accumulate_flat_grads( - accum: dict[str, mx.array] | None, - 
grads_tree: dict, - scale: float, -) -> dict[str, mx.array]: - flat = dict(tree_flatten(grads_tree)) - if accum is None: - return {k: g * scale for k, g in flat.items()} - for k, g in flat.items(): - accum[k] = accum[k] + g * scale - return accum - - -# ============================================================================== -# MATH HELPERS -# ============================================================================== - -def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: - return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) - - -def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - # Background on Muon: https://kellerjordan.github.io/posts/muon/ - a, b, c = 3.4445, -4.7750, 2.0315 - x = g.astype(mx.float32) - x = x / (mx.sqrt(mx.sum(x * x)) + eps) - transposed = x.shape[0] > x.shape[1] - if transposed: - x = x.T - for _ in range(steps): - a_mat = x @ x.T - b_mat = b * a_mat + c * (a_mat @ a_mat) - x = a * x + b_mat @ x - if transposed: - x = x.T - return x.astype(g.dtype) - - -def load_data_shard(path: Path) -> np.ndarray: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - if self.file_idx == 0: - self.epoch += 1 - if self.log_fn is not None: - self.log_fn( - f"WARNING: starting epoch:{self.epoch} " - f"dataset:{self.dataset_name} train_shards:{len(self.files)}" - ) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> np.ndarray: - chunks: list[np.ndarray] = [] - left = n - while left > 0: - if self.pos >= self.tokens.size: - self.next_file() - k = min(left, int(self.tokens.size - self.pos)) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - left -= k - return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) - - -class TokenLoader: - def __init__( - self, - pattern: str, - log_fn: Callable[[str], None] | None = None, - dataset_name: str = "", - ): - self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) - - def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: - usable = (batch_tokens // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - chunk = self.stream.take(usable + 1) - x = chunk[:-1].reshape(-1, seq_len) - y = chunk[1:].reshape(-1, seq_len) - return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) - - -# ============================================================================== -# MODEL BLOCKS -# ============================================================================== - -class CastedLinear(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) - - def __call__(self, x: mx.array) -> mx.array: - return x @ self.weight.astype(x.dtype).T - - -class RMSNormNoWeight(nn.Module): - # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. 
- def __call__(self, x: mx.array) -> mx.array: - return rms_norm(x) - - -class CausalSelfAttention(nn.Module): - # - separate q/k/v projections - # - RMSNorm on q and k before attention - # - RoPE on q and k - # - causal masked SDPA - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim) - self.c_k = CastedLinear(dim, kv_dim) - self.c_v = CastedLinear(dim, kv_dim) - self.proj = CastedLinear(dim, dim) - self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init - self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) - self.scale = self.head_dim ** -0.5 - - def __call__(self, x: mx.array) -> mx.array: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - - q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) - k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) - q = q * self.q_gain.astype(q.dtype)[None, :, None, None] - y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") - y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. 
- def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = dim * mlp_mult - self.fc = CastedLinear(dim, hidden) - self.proj = CastedLinear(hidden, dim) - - def __call__(self, x: mx.array) -> mx.array: - x = nn.relu(self.fc(x)) - return self.proj(x * x) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNormNoWeight() - self.mlp_norm = RMSNormNoWeight() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = mx.ones((dim,), dtype=mx.float32) - self.mlp_scale = mx.ones((dim,), dtype=mx.float32) - self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) - - def __call__(self, x: mx.array, x0: mx.array) -> mx.array: - mix = self.resid_mix.astype(x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - # - token embedding + RMSNorm - # - encoder half accumulates skip tensors - # - decoder half consumes reversed skips with learned skip_weights - # - tied embeddings for the LM head (the baseline default setup) - def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, - qk_gain_init: float): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.logit_chunk_tokens = logit_chunk_tokens - self.logit_softcap = logit_softcap - - self.tok_emb = nn.Embedding(vocab_size, dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) - self.blocks = [ - Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for i in range(num_layers) - ] - self.final_norm = RMSNormNoWeight() - - for b in self.blocks: - b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) - b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) - self.tok_emb.weight = ( - mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std - ).astype(COMPUTE_DTYPE) - - def softcap(self, logits: mx.array) -> mx.array: - c = self.logit_softcap - return c * mx.tanh(logits / c) - - def __call__(self, input_ids: mx.array) -> mx.array: - x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) - x0 = x - skips: list[mx.array] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - # Odd layer counts have one more decoder block than encoder block. The baseline only - # applies a skip connection when one exists, then runs the remaining decoder block(s) - # without an added skip. 
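# Worked example (comment only): with num_layers=9 the split is num_encoder_layers = 9 // 2 = 4,
# num_decoder_layers = 5, num_skip_weights = 4. Decoder iterations i = 0..3 each pop the most
# recently stored encoder activation (so skips are consumed in reverse order) and add
# skip_weights[i] * skip, and the fifth decoder block runs with no skip term because the list
# is empty by then.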
- if skips: - x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - return self.final_norm(x) - - def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: - # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful - # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). - x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) - y = target_ids.reshape(-1) - if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: - logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") - - loss_sum = mx.array(0.0, dtype=mx.float32) - n = int(x.shape[0]) - for s in range(0, n, self.logit_chunk_tokens): - e = min(s + self.logit_chunk_tokens, n) - logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") - return loss_sum / float(n) - -# ============================================================================== -# OPTIMIZERS (MUON + ADAM SPLIT) -# ============================================================================== -class Muon: - # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the - # parameter update. - def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): - self.keys = keys - self.args = args - self.buffers = {k: mx.zeros_like(params[k]) for k in keys} - - def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: - if self.args.muon_momentum_warmup_steps: - t = min(step / self.args.muon_momentum_warmup_steps, 1.0) - momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum - else: - momentum = self.args.muon_momentum - lr = self.args.matrix_lr * lr_mul - out: dict[str, mx.array] = {} - for k in self.keys: - p = params[k] - g = grads[k] - buf = momentum * self.buffers[k] + g - self.buffers[k] = buf - g_eff = g + momentum * buf - g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) - scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) - out[k] = p - lr * (g_ortho * scale).astype(p.dtype) - return out - - -class SplitOptimizers: - # - embeddings: Adam with the tied-embedding LR - # - block matrices (2D): Muon - # - block scalars + skip weights: Adam - # This preserves the high-level optimization behavior even though MLX internals differ. 
- def __init__(self, model: GPT, args: Hyperparameters): - self.args = args - params = dict(tree_flatten(model.parameters())) - self.embed_key = "tok_emb.weight" - self.matrix_keys = [ - k - for k, p in params.items() - if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - self.scalar_keys = [ - k - for k, p in params.items() - if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) - ] - - self.muon = Muon(self.matrix_keys, params, args) - self.adam_embed = optim.Adam( - learning_rate=args.tied_embed_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - self.adam_scalar = optim.Adam( - learning_rate=args.scalar_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - - def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: - params = dict(tree_flatten(model.parameters())) - grads = dict(tree_flatten(grads_tree)) - updated = dict(params) - - updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) - - self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul - updated.update( - self.adam_embed.apply_gradients( - {self.embed_key: grads[self.embed_key]}, - {self.embed_key: params[self.embed_key]}, - ) - ) - - self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul - scalar_grads = {k: grads[k] for k in self.scalar_keys} - scalar_params = {k: params[k] for k in self.scalar_keys} - updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) - - model.update(tree_unflatten(list(updated.items()))) - -# ============================================================================== -# QUANTIZATION (INT8 + ZLIB) -# ============================================================================== -# - per-row int8 for 2D float tensors -# - per-tensor int8 for other float tensors -# - fp16 passthrough for small float tensors -# - exact passthrough for non-floats - -MX_DTYPE_FROM_NAME = { - "float32": mx.float32, - "float16": mx.float16, - "bfloat16": mx.bfloat16, -} - -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 -INT8_PER_ROW_SCALE_DTYPE = np.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - - -def _np_float32(arr: mx.array) -> np.ndarray: - return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) - - -def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return np.ascontiguousarray(_np_float32(arr)) - if arr.dtype in {mx.float32, mx.bfloat16}: - passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] - return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) - return np.ascontiguousarray(np.array(arr, copy=True)) - - -def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: - f32 = _np_float32(arr) - if f32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. 
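# Numeric example (comment only): if a row's clipped max-abs value is 0.5, its scale becomes
# max(0.5 / 127, 1 / 127) ~ 0.003937; an entry of 0.3 is stored as round(0.3 / 0.003937) = 76
# and dequantizes to 76 * 0.003937 ~ 0.299, so the quantization error stays proportional to
# that row's own range rather than to the largest value anywhere in the tensor.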
- clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) - clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) - scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) - q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 - scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) - q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), scale - - -def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: - quantized: dict[str, np.ndarray] = {} - scales: dict[str, np.ndarray] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, np.ndarray] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, arr in flat_state.items(): - stats["param_count"] += int(arr.size) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += int(arr.nbytes) - if not mx.issubdtype(arr.dtype, mx.floating): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = np.ascontiguousarray(np.array(arr)) - stats["int8_payload_bytes"] += int(passthrough[name].nbytes) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_array(name, arr, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += int(kept.nbytes) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_array(arr) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(arr.dtype).split(".")[-1] - stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - - -def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: - out: dict[str, mx.array] = {} - qmeta = quant_obj.get("qmeta", {}) - passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) - for name, q in quant_obj["quantized"].items(): - q_np = np.asarray(q, dtype=np.int8) - dtype_name = quant_obj["dtypes"][name] - scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) - if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: - # Broadcast the saved row scale back across trailing dimensions. 
- out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) - else: - out_arr = q_np.astype(np.float32) * float(scale) - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) - for name, arr in quant_obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_arr = np.array(arr, copy=True) - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) - else: - out[name] = mx.array(out_arr) - return out - - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_lut = np.zeros((table_size,), dtype=np.int16) - has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_lut[token_id] = False - if sp.is_byte(token_id): - base_bytes_lut[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_lut[token_id] = True - piece = piece[1:] - base_bytes_lut[token_id] = len(piece.encode("utf-8")) - return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut - - -def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: - # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we - # decode bytes with the exact tokenizer that produced the shards. The manifest - # lets the training script fail fast on accidental dataset/tokenizer mismatches. 
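# Manifest shape assumed by the lookups below (illustrative, reconstructed from this function):
#     {"datasets":   [{"name": "fineweb10B_sp1024", "tokenizer_name": "...",
#                      "stats": {"files_train": <shard count>}}],
#      "tokenizers": [{"name": "...", "model_path": "<path ending in the .model file name>"}]}
# Only two things are enforced: the tokenizer file name must match the manifest entry, and the
# shard directory must not contain more train shards than the manifest records; if the manifest
# or the dataset entry is missing, the function falls back to
# (dataset_dir.name, actual_train_files, None) without raising.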
- dataset_dir = Path(data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - if len(dataset_dir.parents) < 2: - return dataset_dir.name, actual_train_files, None - manifest_path = dataset_dir.parents[1] / "manifest.json" - if not manifest_path.is_file(): - return dataset_dir.name, actual_train_files, None - - manifest = json.loads(manifest_path.read_text(encoding="utf-8")) - dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) - if dataset_entry is None: - return dataset_dir.name, actual_train_files, None - - tokenizer_name = dataset_entry.get("tokenizer_name") - tokenizer_entry = ( - next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) - if tokenizer_name - else None - ) - expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name - if expected_name and Path(tokenizer_path).name != expected_name: - raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") - expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") - if expected_train_files is not None: - expected_train_files = int(expected_train_files) - if actual_train_files > expected_train_files: - raise ValueError( - f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " - f"manifest says {expected_train_files}" - ) - return dataset_dir.name, actual_train_files, expected_train_files - - -def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
- tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) - usable = ((tokens.size - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def loss_and_grad_chunked( - args: Hyperparameters, - train_loader: TokenLoader, - compiled_loss_and_grad, -) -> tuple[mx.array, dict]: - chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) - total_tokens = float(sum(chunk_sizes)) - loss_value = mx.array(0.0, dtype=mx.float32) - grad_accum: dict[str, mx.array] | None = None - for chunk_tokens in chunk_sizes: - x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) - loss, grads = compiled_loss_and_grad(x, y) - scale = float(y.size) / total_tokens - loss_value = loss_value + loss.astype(mx.float32) * scale - grad_accum = accumulate_flat_grads(grad_accum, grads, scale) - if args.mlx_eager_eval: - mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory - return loss_value, tree_unflatten(list(grad_accum.items())) - - -def eval_val( - args: Hyperparameters, - compiled_loss, - val_tokens: np.ndarray, - base_bytes_lut: np.ndarray, - has_leading_space_lut: np.ndarray, - is_boundary_token_lut: np.ndarray, - log_fn: Callable[[str], None] | None = None, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - val_batch_seqs = val_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.size - 1) // args.train_seq_len - total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) - total_loss_sum = 0.0 - total_tokens = 0.0 - total_bytes = 0.0 - for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): - batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - chunk = val_tokens[raw_start:raw_end] - x_np = chunk[:-1].reshape(-1, args.train_seq_len) - y_np = chunk[1:].reshape(-1, args.train_seq_len) - x = mx.array(x_np, dtype=mx.int32) - y = mx.array(y_np, dtype=mx.int32) - chunk_token_count = float(y.size) - batch_loss = compiled_loss(x, y).astype(mx.float32) - mx.eval(batch_loss) - total_loss_sum += float(batch_loss.item()) * chunk_token_count - prev_ids = x_np.reshape(-1) - tgt_ids = y_np.reshape(-1) - bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) - bytes_np += ( - has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] - ).astype(np.int16, copy=False) - total_tokens += chunk_token_count - total_bytes += float(bytes_np.astype(np.float64).sum()) - if log_fn is not None and total_batches > 1 and ( - batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 - ): - log_fn(f"val_progress:{batch_idx}/{total_batches}") - val_loss = total_loss_sum / total_tokens - bits_per_token = val_loss / math.log(2.0) - val_bpb = bits_per_token * (total_tokens / total_bytes) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# 
----------------------------- - -def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: - if max_norm <= 0: - return grads_tree - flat = dict(tree_flatten(grads_tree)) - total_sq = 0.0 - for grad in flat.values(): - total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) - if total_sq <= 0.0: - return grads_tree - total_norm = math.sqrt(total_sq) - if total_norm <= max_norm: - return grads_tree - scale = max_norm / (total_norm + 1e-12) - return tree_unflatten([(k, g * scale) for k, g in flat.items()]) - - -def main() -> None: - # ============================================================================== - # TOKENIZER + VALIDATION METRIC SETUP - # ============================================================================== - args = Hyperparameters() - out_dir = Path(args.out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - logfile = out_dir / f"{args.run_id}.txt" - print(logfile) - - def log(msg: str, console: bool = True) -> None: - if console: - print(msg) - with logfile.open("a", encoding="utf-8") as f: - print(msg, file=f) - - code = Path(__file__).read_text(encoding="utf-8") - log(code, console=False) - log("=" * 100, console=False) - log(f"Running Python {sys.version}", console=False) - log(f"Running MLX {mx.__version__}", console=False) - log("=" * 100, console=False) - - if not args.tie_embeddings: - raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( - args.data_path, - args.tokenizer_path, - ) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size - ) - - # ============================================================================== - # TRAINING SETUP - # ============================================================================== - mx.random.seed(args.seed) - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - # ============================================================================== - # MODEL + OPTIMIZER SETUP - # ============================================================================== - model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - logit_chunk_tokens=args.logit_chunk_tokens, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - tied_embed_init_std=args.tied_embed_init_std, - qk_gain_init=args.qk_gain_init, - ) - opt = SplitOptimizers(model, args) - - # ============================================================================== - # COMPILED TRAIN / EVAL FUNCTIONS (MLX) - # ============================================================================== - # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example - # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". 
- # Compiling the model-bound functions and capturing the full model state fixes that while still - # returning gradients only for trainable parameters via nn.value_and_grad(...). - compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, - outputs=model.state, - ) - - # Print config once so logs are self-describing. - n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) - log(f"run_id:{args.run_id}") - log(f"mlx_version:{mx.__version__}") - log(f"train_loader:shards pattern={args.train_files}") - log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") - if expected_train_files is None: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") - elif actual_train_files < expected_train_files: - log( - f"WARNING: train_loader:subset dataset:{dataset_name} " - f"train_shards:{actual_train_files}/{expected_train_files} " - f"new epochs will arrive sooner than the full dataset" - ) - else: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") - log(f"tokenizer_path:{args.tokenizer_path}") - log( - f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " - f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " - f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" - ) - log( - f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " - f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " - f"val_batch_size:{args.val_batch_size} " - f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") - log( - f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " - f"embed_lr:{args.tied_embed_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " - f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" - ) - log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") - log( - f"dtypes tok_emb:{model.tok_emb.weight.dtype} " - f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " - f"skip_weights:{model.skip_weights.dtype}" - ) - - # ============================================================================== - # TRAINING LOOP - # ============================================================================== - if args.warmup_steps > 0: - # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us - # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. - # Instead we run the real train shapes, force the loss/grads to materialize, and then reset - # the loader so measured training still starts from the true init and token window. 
- for warmup_step in range(args.warmup_steps): - accum: dict[str, mx.array] | None = None - warmup_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - mx.eval(warmup_loss, accum) - mx.synchronize() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - - # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) - warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] - x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) - y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) - warm_val_loss = compiled_loss(x_val, y_val) - mx.eval(warm_val_loss) - mx.synchronize() - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - train_time_ms = 0.0 - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - stop_after_step: int | None = None - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - train_time_ms += 1000.0 * (time.perf_counter() - t0) - # Validation always scans the same fixed full validation split. 
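The val_bpb figure reported here (and by the PyTorch evaluators introduced later in this diff) is derived from the mean token NLL with the same conversion; a minimal sketch of that arithmetic:

import math

def bits_per_byte(mean_nll_nats: float, n_tokens: int, n_bytes: int) -> float:
    # nats/token -> bits/token, then rescale by tokens-per-byte so the metric is
    # per UTF-8 byte of decoded text rather than per (tokenizer-dependent) token.
    bits_per_token = mean_nll_nats / math.log(2.0)
    return bits_per_token * (n_tokens / n_bytes)

Byte counts come from the SentencePiece lookup tables built above: each target token contributes the UTF-8 length of its piece, plus one byte for its leading space unless the previous token is a boundary/control token.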
- val_loss, val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - if step % 25 == 0 or last_step: - log( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" - ) - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) - step_t0 = time.perf_counter() - - accum: dict[str, mx.array] | None = None - train_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - train_loss = train_loss + loss.astype(mx.float32) * grad_scale - if args.mlx_eager_eval: - mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory - - grads = tree_unflatten(list(accum.items())) - grads = clip_grad_tree(grads, args.grad_clip_norm) - train_loss_value = float(train_loss.item()) - opt.step(model, grads, step=step, lr_mul=lr_mul) - mx.synchronize() - - step_ms = 1000.0 * (time.perf_counter() - step_t0) - approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) - tok_s = args.train_batch_tokens / (step_ms / 1000.0) - step += 1 - if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): - log( - f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " - f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" - ) - if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: - stop_after_step = step - - # ============================================================================== - # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL - # ============================================================================== - # We always write a raw artifact and a quantized artifact, then validate the - # quantized roundtrip directly by loading the dequantized tensors back into the - # model and running one final validation pass. 
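The roundtrip being validated here is per-row symmetric quantization. A minimal NumPy sketch of the core transform under that assumption; the real quantize_state_dict_int8 / dequantize_state_dict_int8 additionally search clip percentiles, keep small control tensors in float, and zlib-compress the pickled payload:

import numpy as np

def quantize_rows(w: np.ndarray, qmax: int = 127) -> tuple[np.ndarray, np.ndarray]:
    # One scale per output row (row_absmax / qmax), stored in fp16 alongside int8 codes.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / qmax, 1e-8)
    q = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int8)
    return q, scale.astype(np.float16)

def dequantize_rows(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Reconstruction that the roundtrip eval loads back into the model before re-running validation.
    return q.astype(np.float32) * scale.astype(np.float32)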
- out_path = out_dir / f"{args.run_id}_mlx_model.npz" - flat_state = {k: v for k, v in tree_flatten(model.state)} - mx.savez(str(out_path), **flat_state) - log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") - - quant_obj, quant_stats = quantize_state_dict_int8(flat_state) - quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) - quant_blob = zlib.compress(quant_raw, level=9) - quant_serialized_bytes = len(quant_raw) - quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" - with quant_path.open("wb") as f: - f.write(quant_blob) - quant_file_bytes = quant_path.stat().st_size - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log( - f"serialized_model_int8_zlib:{quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" - ) - - with quant_path.open("rb") as f: - quant_blob_disk = f.read() - quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) - model.update(tree_unflatten(list(quant_flat.items()))) - q_t0 = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) - log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") - log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - -if __name__ == "__main__": - main() diff --git a/train_gpt_phase0_defaults.py b/train_gpt_phase0_defaults.py new file mode 100644 index 0000000000..e3efb21bbf --- /dev/null +++ b/train_gpt_phase0_defaults.py @@ -0,0 +1,1874 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + 
num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + warmdown_schedule = os.environ.get("WARMDOWN_SCHEDULE", "sqrt") + muoneq_r = bool(int(os.environ.get("MUONEQ_R", "1"))) + weight_decay = float(os.environ.get("WEIGHT_DECAY", "0.04")) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, 
backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: 
nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, 
len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + 
seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 
1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: 
remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not 
None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), 
-clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = 
q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + 
self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + 
self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = 
MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + 
self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is 
None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, 
is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + 
eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + 
model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = 
sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = 
torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase0_misc.py b/train_gpt_phase0_misc.py new file mode 100644 index 0000000000..7021994705 --- /dev/null +++ b/train_gpt_phase0_misc.py @@ -0,0 +1,1920 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess 
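# ---------------------------------------------------------------------------
# Illustrative sketch: restoring the compressed .ptz artifact written at the
# end of the training script above, mirroring its own roundtrip
# (zstd/zlib decompress -> torch.load -> dequantize -> load_state_dict).
# The path shown in the usage note is a placeholder, and a GPT instance built
# with the same hyperparameters is assumed to already exist.
# ---------------------------------------------------------------------------
import io
import zlib

import torch

try:
    import zstandard as zstd
    _HAS_ZSTD_SKETCH = True
except ImportError:
    _HAS_ZSTD_SKETCH = False


def load_ptz_artifact(path: str) -> dict:
    """Decompress a .ptz artifact and return the raw quantized object."""
    with open(path, "rb") as f:
        blob = f.read()
    raw = zstd.ZstdDecompressor().decompress(blob) if _HAS_ZSTD_SKETCH else zlib.decompress(blob)
    return torch.load(io.BytesIO(raw), map_location="cpu")

# Usage (assumes dequantize_state_dict_int8 and a matching model are in scope):
#   quant_obj = load_ptz_artifact("runs/<run_id>/final_model.int6.ptz")
#   model.load_state_dict(dequantize_state_dict_int8(quant_obj), strict=True)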
+import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +try: + import brotli as _brotli + _HAS_BROTLI = True +except ImportError: + _HAS_BROTLI = False +import lzma +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + arr = bytearray(data) + n = len(arr) + out = bytearray(n) + for s in range(stride): + src_start = s + dst_start = s * (n // stride) + min(s, n % stride) + count = (n - s + stride - 1) // stride + for i in range(count): + out[dst_start + i] = arr[src_start + i * stride] + return bytes(out) + +def _byte_unshuffle(data: bytes, stride: int = 2) -> bytes: + arr = bytearray(data) + n = len(arr) + out = bytearray(n) + for s in range(stride): + src_start = s * (n // stride) + min(s, n % stride) + dst_start = s + count = (n - s + stride - 1) // stride + for i in range(count): + out[dst_start + i * stride] = arr[src_start + i] + return bytes(out) + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 
500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + warmdown_schedule = os.environ.get("WARMDOWN_SCHEDULE", "sqrt") + weight_decay = float(os.environ.get("WEIGHT_DECAY", "0.04")) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + 
p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + wd = group.get("weight_decay", 0.0) + if wd > 0: + p.add_(p, alpha=-wd * lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for 
batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += 
(has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if 
hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, 
s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in 
range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / 
max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. 
When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def 
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout: a 256-entry int32 header (header[2] holds the token count) followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens_np = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    # Widen to int32 so the downstream .to(torch.int64) casts and slicing are safe.
+    return torch.from_numpy(tokens_np.astype(np.int32))
+class TokenStream:
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+class CastedLinear(nn.Linear):
+    _int6_qat: bool = False
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight.to(x.dtype)
+        if CastedLinear._int6_qat and self.training:
+            w = int6_ste(w)
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, w, bias)
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+class BigramHash(nn.Module):
+    def __init__(self, num_buckets: int, dim: int):
+        super().__init__()
+        self.num_buckets = num_buckets
+        self.emb = nn.Embedding(num_buckets, dim)
+        nn.init.zeros_(self.emb.weight)
+    def forward(self, input_ids: Tensor) -> Tensor:
+        prev = F.pad(input_ids[:, :-1], (1, 0), value=0)
+        bigram_hash = (prev * 31 + input_ids) % self.num_buckets
+        return self.emb(bigram_hash)
+class SmearGate(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.gate = nn.Parameter(torch.ones(dim,
dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] 
+ y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> 
Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module 
in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True, mode="reduce-overhead") + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + group["weight_decay"] = args.weight_decay + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if 
_HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + if not (warmdown_start <= step < args.iterations): + return 1.0 + t_frac = (step - warmdown_start) / max(args.warmdown_iters, 1) + else: + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + if remaining_ms > warmdown_ms: + return 1.0 + t_frac = 1.0 - remaining_ms / max(warmdown_ms, 1e-9) + if args.warmdown_schedule == "sqrt": + return max(1.0 - (t_frac ** 0.5), 0) + else: + return max(1.0 - t_frac, 0) + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + 
p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].lerp_(p.data, 1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + 
log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_BROTLI: + shuffled = _byte_shuffle(quant_raw) + quant_blob = _brotli.compress(shuffled, quality=11) + elif _HAS_ZSTD: + quant_blob = zstd.ZstdCompressor(level=22, threads=-1).compress(quant_raw) + else: + quant_blob = lzma.compress(quant_raw, preset=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_BROTLI: + quant_raw = _byte_unshuffle(_brotli.decompress(quant_blob_disk)) + elif _HAS_ZSTD: + quant_raw = zstd.ZstdDecompressor().decompress(quant_blob_disk) + else: + quant_raw = lzma.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_raw), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase0_muoneqr.py b/train_gpt_phase0_muoneqr.py new file mode 100644 index 0000000000..249365af84 --- /dev/null +++ b/train_gpt_phase0_muoneqr.py @@ -0,0 +1,1875 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + 
train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + muoneq_r = bool(int(os.environ.get("MUONEQ_R", "1"))) +def 
zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if group.get("muoneq_r", False): + g = g / g.norm(dim=-1, keepdim=True).clamp_min(1e-7) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 
15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + 
return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear 
wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - 
chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = 
torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > 
val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if 
t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = 
idx % world_size
+            if owner == rank:
+                # Owner rank sends its quantized tensor and per-row scales to the other ranks
+                q_tensor = quantized[name]
+                s_tensor = scales[name]
+                # Send via broadcast
+                q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda()
+                s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda()
+                dist.broadcast(q_gpu, src=owner)
+                dist.broadcast(s_gpu, src=owner)
+                quantized[name] = q_gpu.cpu()
+                scales[name] = s_gpu.cpu()
+            else:
+                # Allocate matching tensors and receive
+                q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape
+                q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda")
+                if t.ndim == 2 and t.numel() > 0:
+                    s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda")
+                else:
+                    s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda")
+                dist.broadcast(q_gpu, src=owner)
+                dist.broadcast(s_gpu, src=owner)
+                quantized[name] = q_gpu.cpu()
+                scales[name] = s_gpu.cpu()
+                dtypes[name] = str(t.dtype).removeprefix("torch.")
+                if s_gpu.ndim > 0:
+                    qmeta[name] = {"scheme": "per_row", "axis": 0}
+                stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name])
+    obj: dict[str, object] = {
+        "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout (llm.c-style fineweb .bin): 256 int32 header with the token count at header[2], followed by uint16 tokens.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(2 * num_tokens), dtype=np.uint16)
+    return torch.from_numpy(tokens.astype(np.int32))
+class TokenStream:
+    # Streams tokens sequentially across shard files, wrapping around when the last shard is exhausted.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[0])
+        self.pos = 0
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk =
self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + 
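+        # Grouped-query attention: when num_kv_heads < num_heads, groups of query heads share a single K/V head, shrinking the c_k/c_v projections.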
self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, 
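+        # qk_gain_init: initial value of the learned per-head query gain applied after QK RMS-norm.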
+ qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + 
recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: 
list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = 
spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + group["muoneq_r"] = 
args.muoneq_r + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): 
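+                # Warmup steps run the full compiled forward/backward/optimizer path; the initial model and optimizer snapshots are restored right after this loop so the timed run starts from scratch.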
+ warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + 
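+                    # EMA of weights after each optimizer step: ema = decay * ema + (1 - decay) * param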
ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} 
val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase1_sp4096.py b/train_gpt_phase1_sp4096.py new file mode 100644 index 0000000000..185b7bb21c --- /dev/null +++ b/train_gpt_phase1_sp4096.py @@ -0,0 +1,1886 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", f"./data/datasets/fineweb10B_sp{os.environ.get('VOCAB_SIZE', '4096')}") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", f"./data/tokenizers/fineweb_{os.environ.get('VOCAB_SIZE', '4096')}_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 4096)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = 
int(os.environ.get("BIGRAMHASH_BUCKETS", 3072)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + bigramhash_dim = int(os.environ.get("BIGRAMHASH_DIM", 112)) + encoder_layers = int(os.environ.get("ENCODER_LAYERS", 1)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + 
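+                # After the SUM all-reduce every rank holds the complete flat update buffer, so applying it locally keeps parameters identical across ranks.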
curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = 
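# Illustrative sketch: int6_ste is a straight-through estimator -- the forward
# pass sees quantized weights, but the backward pass treats quantization as the
# identity, so gradients update the full-precision master weights. The same
# trick in a standalone per-tensor form (qmax of 31 assumed, as for int6):
import torch

def ste_quant(w: torch.Tensor, qmax: int = 31) -> torch.Tensor:
    scale = w.abs().amax().clamp_min(1e-8) / qmax
    w_q = (w / scale).round().clamp(-qmax, qmax) * scale
    return w + (w_q - w).detach()        # value of w_q, gradient of the identity

w = torch.randn(4, 4, requires_grad=True)
ste_quant(w).sum().backward()
assert torch.allclose(w.grad, torch.ones_like(w))    # gradient passed straight through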
local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / 
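# Illustrative sketch: eval_val converts the mean token cross-entropy (nats) to
# bits-per-byte so results are comparable across tokenizers -- bits per token,
# rescaled by the tokens-per-byte ratio computed from the SentencePiece byte LUTs.
# Worked numerically with made-up figures:
import math

val_loss = 1.95                                    # mean NLL in nats/token
tokens_per_byte = 0.45                             # val_token_count / val_byte_count
bits_per_token = val_loss / math.log(2.0)          # ~2.813 bits/token
bpb = bits_per_token * tokens_per_byte             # ~1.266 bits/byte
print(f"bpb = {bpb:.4f}")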
token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or 
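# Illustrative sketch: LoRAAdapter wraps a frozen nn.Linear with a forward hook
# that adds the rank-r update x @ A @ B to the layer's output; only A and B are
# trainable. A standalone check of that mechanism, assuming LoRAAdapter above is
# in scope (the layer sizes here are arbitrary):
import torch
from torch import nn

base = nn.Linear(16, 8, bias=False)
base.weight.requires_grad_(False)
adapter = LoRAAdapter(base, rank=4)
adapter.register()

out = base(torch.randn(2, 16))        # output now includes the x @ A @ B term
out.sum().backward()
print(base.weight.grad is None)       # True: the base layer stays frozen
print(adapter.B.grad is not None)     # True: only the adapter accumulates gradients
adapter.remove()                      # detach the hook, restoring the original layer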
tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip 
last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + 
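# Illustrative sketch: during test-time training the per-chunk learning rate
# follows a cosine decay over the chunk index -- full lr on the first trained
# chunk, approaching zero on the last. The schedule used above, isolated
# (the base_lr and chunk count here are arbitrary):
import math

def ttt_chunk_lr(base_lr: float, ci: int, num_chunks: int) -> float:
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1)))

for ci in (0, 5, 10):
    print(ci, round(ttt_chunk_lr(0.002, ci, 11), 6))   # 0.002, 0.001, 0.0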
f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + 
args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. 
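# Illustrative sketch: the export path quantizes each 2-D weight per row --
# pick a clipping threshold, derive scale = clip / qmax, round to signed
# integers, and store int8 codes plus fp16 per-row scales. A simplified round
# trip of that scheme (plain per-row max clipping, qmax of 31 assumed; the
# GPTQ-lite search above additionally tries several clipping percentiles and
# keeps the one with the lowest per-row MSE):
import torch

def quant_per_row(w: torch.Tensor, qmax: int = 31) -> tuple[torch.Tensor, torch.Tensor]:
    scale = (w.abs().amax(dim=1, keepdim=True) / qmax).clamp_min(1e-8)
    q = (w / scale).round().clamp(-qmax, qmax).to(torch.int8)
    return q, scale.squeeze(1).half()

def dequant_per_row(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() * scale.float()[:, None]

w = torch.randn(256, 256)
q, s = quant_per_row(w)
print((w - dequant_per_row(q, s)).abs().max().item())   # bounded by ~half a quantization step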
When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def 
dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int, proj_dim: int = 0): + super().__init__() + self.num_buckets = num_buckets + embed_dim = proj_dim if proj_dim > 0 and proj_dim < dim else dim + self.emb = nn.Embedding(num_buckets, embed_dim) + self.proj = nn.Linear(embed_dim, dim, bias=False) if embed_dim < dim else None + self.scale = nn.Parameter(torch.tensor(0.05)) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), 
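# Hedged reconstruction: the bodies of load_data_shard and TokenStream are cut
# off in the text above (it jumps from `np.dtype("` straight to `_advance_file`).
# The sketch below shows their likely shape, assuming the usual .bin layout of a
# 256-entry int32 header whose third field is the token count, followed by
# uint16 token ids; the exact header checks and dtypes are assumptions.
import glob
from pathlib import Path
import numpy as np
import torch

def load_data_shard_sketch(file: Path) -> torch.Tensor:
    raw = file.read_bytes()
    header_bytes = 256 * np.dtype("<i4").itemsize
    header = np.frombuffer(raw[:header_bytes], dtype="<i4")
    num_tokens = int(header[2])                                # assumed header layout
    tokens = np.frombuffer(raw[header_bytes:], dtype=np.uint16)[:num_tokens]
    return torch.from_numpy(tokens.astype(np.int32))

class TokenStreamSketch:
    """Assumed shape of TokenStream: sequential reads, wrapping across shard files."""
    def __init__(self, pattern: str):
        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
        self.file_idx = 0
        self.tokens = load_data_shard_sketch(self.files[self.file_idx])
        self.pos = 0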
value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + h = self.emb(bigram_hash) + if self.proj is not None: + h = self.proj(h) + return h * self.scale +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = 
apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = int(mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + 
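# Illustrative sketch: the MLP activation is a squared leaky-ReLU,
# h = leaky_relu(x, 0.5) ** 2 -- quadratic on both sides of zero, with a
# flatter branch for negative inputs. Its values at a few points:
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(F.leaky_relu(x, negative_slope=0.5).square())   # [1.0, 0.0625, 0.0, 0.25, 4.0]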
self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + bigramhash_dim: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + encoder_layers: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim, proj_dim=bigramhash_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + if encoder_layers > 0: + self.num_encoder_layers = encoder_layers + else: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else 
CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be 
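# Illustrative sketch: the logit soft-cap bounds every logit smoothly to
# (-cap, +cap) via logits = cap * tanh(raw / cap); small values pass through
# almost unchanged while large ones saturate. Evaluated at a few points:
import torch

cap = 30.0
raw = torch.tensor([0.5, 10.0, 100.0, -500.0])
print(cap * torch.tanh(raw / cap))   # ~[0.50, 9.65, 29.92, -30.00]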
positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + 
rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + bigramhash_dim=args.bigramhash_dim, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + encoder_layers=args.encoder_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + 
) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % 
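# Illustrative sketch: with a wallclock cap, lr_mul switches the linear warmdown
# from step-based to time-based -- it estimates per-step time from elapsed_ms,
# projects how long warmdown_iters would take, and starts decaying once the
# remaining budget fits inside that window. Isolated with hypothetical numbers
# (50 ms/step, 600 s cap, 3500 warmdown steps, i.e. a 175 s warmdown window):
def lr_mul_sketch(step: int, elapsed_ms: float, max_ms: float, warmdown_iters: int) -> float:
    step_ms = elapsed_ms / max(step, 1)
    warmdown_ms = warmdown_iters * step_ms
    remaining_ms = max(max_ms - elapsed_ms, 0.0)
    return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

print(lr_mul_sketch(step=2_000, elapsed_ms=100_000, max_ms=600_000, warmdown_iters=3500))    # 1.0
print(lr_mul_sketch(step=10_000, elapsed_ms=500_000, max_ms=600_000, warmdown_iters=3500))   # ~0.571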
args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not 
None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} 
val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase2_depthrecur.py b/train_gpt_phase2_depthrecur.py new file mode 100644 index 0000000000..d8b5d3a27c --- /dev/null +++ b/train_gpt_phase2_depthrecur.py @@ -0,0 +1,1975 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", 
"./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + 
recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + recur_layers = os.environ.get("RECUR_LAYERS", "") # e.g. "3,4,5" — empty = disabled + recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 
15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + 
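+    # Sliding-window scoring: each window spans seq_len tokens but only its last `stride`
+    # tokens are scored (the first window scores everything), so every scored token sees
+    # close to seq_len of left context. BPB = (nll / ln 2) * (scored tokens / scored bytes);
+    # illustrative numbers: 2.03 nats/token ≈ 2.93 bits/token, at ~2.3 bytes/token ≈ 1.27 BPB.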
return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear 
wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - 
chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = 
torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > 
val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if 
t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = 
idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = 
self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + 
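+        # GQA: K/V are projected to num_kv_heads * head_dim and shared across query heads via
+        # SDPA's enable_gqa flag. With the default 512-dim / 8-head / 4-KV config, c_k and c_v
+        # shrink from 512x512 to 512x256 each, saving 2 * 512 * 256 = 262,144 weights per layer
+        # (roughly 2.9M across 11 layers) relative to full MHA.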
self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class RepeatMLP(nn.Module): + """Separate MLP weights for depth recurrence repeat pass. 
Copy-initialized from base MLP at activation step.""" + def __init__(self, dim: int, hidden: int): + super().__init__() + self.fc = nn.Linear(dim, hidden, bias=False) + self.proj = nn.Linear(hidden, dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), 0.5).square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + has_repeat_mlp: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + if has_repeat_mlp: + mlp_hidden = int(dim * mlp_mult) + self.repeat_mlp = RepeatMLP(dim, mlp_hidden) + else: + self.repeat_mlp = None + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x + def forward_repeat(self, x: Tensor, x0: Tensor) -> Tensor: + """Same as forward but using repeat_mlp instead of mlp for the MLP step.""" + mix = self.resid_mix.to(dtype=x.dtype) + x_ln = self.attn_norm(mix[0] * x + mix[1] * x0) + attn_out = self.attn(x_ln) + x = x + self.attn_scale.to(dtype=x.dtype) * attn_out + mlp_in = self.mlp_norm(x) + mlp_out = self.repeat_mlp(mlp_in) # uses repeat_mlp, not self.mlp + x = x + self.mlp_scale.to(dtype=x.dtype) * mlp_out + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + recur_layer_indices: list[int] | None = None, + recur_start_step: int = 3000, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = 
BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + # Depth recurrence (new, untied MLP weights) + self.recur_layer_indices = recur_layer_indices if recur_layer_indices is not None else [] + self.recur_start_step = recur_start_step + self._recur_active = False + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + # If depth recurrence is configured, adjust U-Net split based on virtual layer count + if self.recur_layer_indices: + virtual_num_layers = num_layers + len(self.recur_layer_indices) + self.num_encoder_layers_virtual = virtual_num_layers // 2 + self.num_decoder_layers_virtual = virtual_num_layers - self.num_encoder_layers_virtual + # Map virtual encoder count back to physical layers (best-effort split) + self.num_encoder_layers = min(self.num_encoder_layers_virtual, num_layers) + self.num_decoder_layers = num_layers - self.num_encoder_layers + else: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + has_repeat_mlp=(i in self.recur_layer_indices), + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + 
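# Note: modules flagged with _zero_init (attention/MLP output projections and the untied lm_head) start at zero, so every residual branch is a no-op at initialization. +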
nn.init.zeros_(module.weight) + def _sync_repeat_mlp_from_base(self) -> None: + """Copy MLP weights to repeat_mlp weights at recurrence activation step.""" + for li in self.recur_layer_indices: + block = self.blocks[li] + if block.repeat_mlp is not None: + block.repeat_mlp.fc.weight.data.copy_(block.mlp.fc.weight.data) + block.repeat_mlp.proj.weight.data.copy_(block.mlp.proj.weight.data) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + # Standard encoder pass (physical layers 0..num_encoder_layers-1) + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + if self.skip_weights.numel() > 0: + skips.append(x) + # Depth recurrence pass (if active) — uses separate repeat_mlp weights + if self._recur_active and self.recur_layer_indices: + for li in self.recur_layer_indices: + x = self.blocks[li].forward_repeat(x, x0) + # Standard decoder pass + for i in range(self.num_decoder_layers): + if skips: + skip_idx = self.num_encoder_layers - 1 - i + if 0 <= skip_idx < len(skips): + x = x + self.skip_weights[min(i, len(self.skip_weights) - 1)].to(dtype=x.dtype)[None, None, :] * skips[skip_idx] + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + # Standard encoder pass + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + if self.skip_weights.numel() > 0: + skips.append(x) + # Depth recurrence pass (if active) + if self._recur_active and self.recur_layer_indices: + for li in self.recur_layer_indices: + x = self.blocks[li].forward_repeat(x, x0) + # Standard decoder pass + for i in range(self.num_decoder_layers): + if skips: + skip_idx = self.num_encoder_layers - 1 - i + if 0 <= skip_idx < len(skips): + x = x + self.skip_weights[min(i, len(self.skip_weights) - 1)].to(dtype=x.dtype)[None, None, :] * skips[skip_idx] + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + 
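# Tied embeddings: the token-embedding matrix doubles as the output projection, so no separate lm_head weight is stored in the artifact. +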
logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = 
load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + # Parse recur_layer_indices from args + recur_layer_indices = [int(x) for x in args.recur_layers.split(",") if x.strip()] if args.recur_layers else [] + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + recur_layer_indices=recur_layer_indices, + recur_start_step=args.recur_start_step, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, 
args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + # Add repeat_mlp weights to the Muon optimizer param group + # These are the separate MLP weights used during the depth recurrence repeat pass + if base_model.recur_layer_indices: + repeat_mlp_matrix_params = [] + for li in base_model.recur_layer_indices: + block = base_model.blocks[li] + if block.repeat_mlp is not None: + # repeat_mlp.fc and repeat_mlp.proj are plain nn.Linear (2D weight matrices) + repeat_mlp_matrix_params.append(block.repeat_mlp.fc.weight) + repeat_mlp_matrix_params.append(block.repeat_mlp.proj.weight) + if repeat_mlp_matrix_params: + optimizer_muon.add_param_group({ + "params": repeat_mlp_matrix_params, + "lr": args.matrix_lr, + "base_lr": args.matrix_lr, + "momentum": args.muon_momentum, + "backend_steps": args.muon_backend_steps, + "nesterov": True, + }) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + if base_model.recur_layer_indices: + log0( + f"depth_recurrence_v2: recur_layers={base_model.recur_layer_indices} " + f"recur_start_step={base_model.recur_start_step} " + f"(inactive until step {base_model.recur_start_step})" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < 
args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + # Activate depth recurrence at the designated step + if base_model.recur_layer_indices: + if step == base_model.recur_start_step: + base_model._sync_repeat_mlp_from_base() + base_model._recur_active = True + if master_process: + log0(f"Depth recurrence activated at step {step}, layers {base_model.recur_layer_indices}") + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} 
GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, 
has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase2_parallel_twolane.py b/train_gpt_phase2_parallel_twolane.py new file mode 100644 index 0000000000..86094a737a --- /dev/null +++ b/train_gpt_phase2_parallel_twolane.py @@ -0,0 +1,2041 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + 
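# With the defaults above, one optimizer step covers TRAIN_BATCH_TOKENS=524_288 tokens, i.e. 512 sequences at TRAIN_SEQ_LEN=1024 (256 at 2048); main() presumably splits this across the 8 // WORLD_SIZE gradient-accumulation micro-steps per rank. +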
tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", -1)) # -1 = disabled +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in 
self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( 
+ "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + 
chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} 
lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, 
dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, 
val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) 
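+ # Score only the fresh tail of each window: s = wlen - stride for all but the first window, so each position is evaluated with up to seq_len of left context while the loss/byte sums advance roughly stride tokens per window.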
+ for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. 
When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def 
dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, 
dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] 
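+ # Added commentary (not in the original source): q and k are RMS-normalized per head
+ # ("QK-norm"), then q is multiplied by a learned per-head gain initialized to qk_gain_init,
+ # which acts as a learnable attention temperature (effective logit scale is roughly
+ # q_gain / sqrt(head_dim), since SDPA applies the 1/sqrt(head_dim) factor itself).
+ # With ROPE_PARTIAL_DIMS > 0 only the first rope_dim channels of each head are rotated;
+ # the remaining channels pass through position-free. When use_xsa is set, each position's
+ # own value vector is subtracted from the attention output after SDPA.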
+ y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + parallel_start_layer: int = -1, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.unet_skips = unet_skips + self.parallel_start_layer = parallel_start_layer + if self.parallel_start_layer >= 0: + num_parallel = num_layers - self.parallel_start_layer + # Cross-lane routing: [num_parallel_layers, 2, 2] — how each sublayer writes to each lane + self.parallel_post_lambdas = nn.Parameter(torch.ones(num_parallel, 2, 
2, dtype=torch.float32)) + # Residual carry: [num_parallel_layers, 2] — scaling for residual stream + self.parallel_resid_lambdas = nn.Parameter( + torch.full((num_parallel, 2), math.sqrt(1.1), dtype=torch.float32) + ) + else: + self.parallel_post_lambdas = None + self.parallel_resid_lambdas = None + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + lane0: Tensor | None = None + lane1: Tensor | None = None + + for i in range(self.num_encoder_layers): + layer_idx = i + if self.parallel_start_layer >= 0 and layer_idx >= self.parallel_start_layer: + # Two-lane mode + if lane0 is None: + lane0 = x.clone() + lane1 = x.clone() + + pi = layer_idx - self.parallel_start_layer # parallel index + block = self.blocks[layer_idx] + + mix = block.resid_mix.to(dtype=lane0.dtype) + + # Attn reads from lane0 + attn_in = block.attn_norm(mix[0] * lane0 + mix[1] * x0) + attn_out = block.attn(attn_in) * block.attn_scale.to(dtype=lane0.dtype) + + # MLP reads from lane1 + mlp_in = block.mlp_norm(mix[0] * lane1 + mix[1] * x0) + mlp_out = block.mlp(mlp_in) * block.mlp_scale.to(dtype=lane1.dtype) + + # Cross-lane routing + post = self.parallel_post_lambdas[pi].to(dtype=lane0.dtype) # [2, 2] + resid_l = self.parallel_resid_lambdas[pi].to(dtype=lane0.dtype) # [2] + + # Update lanes: each sublayer writes to both lanes + new_lane0 = resid_l[0] * lane0 + post[0, 0] * attn_out + post[1, 0] * mlp_out + new_lane1 = resid_l[1] * lane1 + post[0, 1] * attn_out + post[1, 1] * mlp_out + lane0, lane1 = new_lane0, new_lane1 + else: + # Standard sequential mode + x = self.blocks[layer_idx](x, x0) + + if self.unet_skips: + skips.append(lane0 if lane0 is not None else x) + + for i in range(self.num_decoder_layers): + layer_idx = self.num_encoder_layers + i + # Apply skip connections (same LIFO logic as original) + if skips: + cur_dtype = lane0.dtype if lane0 is not None else x.dtype + skip_val = self.skip_weights[i].to(dtype=cur_dtype)[None, None, :] * skips.pop() + if lane0 is not None: + lane0 = lane0 + skip_val + lane1 = lane1 + skip_val + else: + x = x + skip_val + + if self.parallel_start_layer >= 0 and layer_idx >= self.parallel_start_layer: + if lane0 is None: + lane0 = x.clone() + lane1 = x.clone() + + pi = layer_idx - self.parallel_start_layer + block = self.blocks[layer_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + + attn_in = block.attn_norm(mix[0] * lane0 + mix[1] * x0) + attn_out = block.attn(attn_in) * block.attn_scale.to(dtype=lane0.dtype) + + mlp_in = block.mlp_norm(mix[0] * lane1 + mix[1] * x0) + mlp_out = block.mlp(mlp_in) * block.mlp_scale.to(dtype=lane1.dtype) + + post = self.parallel_post_lambdas[pi].to(dtype=lane0.dtype) + resid_l = 
self.parallel_resid_lambdas[pi].to(dtype=lane0.dtype) + + new_lane0 = resid_l[0] * lane0 + post[0, 0] * attn_out + post[1, 0] * mlp_out + new_lane1 = resid_l[1] * lane1 + post[0, 1] * attn_out + post[1, 1] * mlp_out + lane0, lane1 = new_lane0, new_lane1 + else: + x = self.blocks[layer_idx](x, x0) + + # Final merge + if lane0 is not None: + x = (lane0 + lane1) * 0.5 + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + lane0: Tensor | None = None + lane1: Tensor | None = None + + for i in range(self.num_encoder_layers): + layer_idx = i + if self.parallel_start_layer >= 0 and layer_idx >= self.parallel_start_layer: + # Two-lane mode + if lane0 is None: + lane0 = x.clone() + lane1 = x.clone() + + pi = layer_idx - self.parallel_start_layer + block = self.blocks[layer_idx] + + mix = block.resid_mix.to(dtype=lane0.dtype) + + # Attn reads from lane0 + attn_in = block.attn_norm(mix[0] * lane0 + mix[1] * x0) + attn_out = block.attn(attn_in) * block.attn_scale.to(dtype=lane0.dtype) + + # MLP reads from lane1 + mlp_in = block.mlp_norm(mix[0] * lane1 + mix[1] * x0) + mlp_out = block.mlp(mlp_in) * block.mlp_scale.to(dtype=lane1.dtype) + + # Cross-lane routing + post = self.parallel_post_lambdas[pi].to(dtype=lane0.dtype) # [2, 2] + resid_l = self.parallel_resid_lambdas[pi].to(dtype=lane0.dtype) # [2] + + # Update lanes: each sublayer writes to both lanes + new_lane0 = resid_l[0] * lane0 + post[0, 0] * attn_out + post[1, 0] * mlp_out + new_lane1 = resid_l[1] * lane1 + post[0, 1] * attn_out + post[1, 1] * mlp_out + lane0, lane1 = new_lane0, new_lane1 + else: + # Standard sequential mode + x = self.blocks[layer_idx](x, x0) + + if self.unet_skips: + skips.append(lane0 if lane0 is not None else x) + + for i in range(self.num_decoder_layers): + layer_idx = self.num_encoder_layers + i + # Apply skip connections (same LIFO logic as original) + if skips: + cur_dtype = lane0.dtype if lane0 is not None else x.dtype + skip_val = self.skip_weights[i].to(dtype=cur_dtype)[None, None, :] * skips.pop() + if lane0 is not None: + lane0 = lane0 + skip_val + lane1 = lane1 + skip_val + else: + x = x + skip_val + + if self.parallel_start_layer >= 0 and layer_idx >= self.parallel_start_layer: + if lane0 is None: + lane0 = x.clone() + lane1 = x.clone() + + pi = layer_idx - self.parallel_start_layer + block = self.blocks[layer_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + + attn_in = block.attn_norm(mix[0] * lane0 + mix[1] * x0) + attn_out = 
block.attn(attn_in) * block.attn_scale.to(dtype=lane0.dtype) + + mlp_in = block.mlp_norm(mix[0] * lane1 + mix[1] * x0) + mlp_out = block.mlp(mlp_in) * block.mlp_scale.to(dtype=lane1.dtype) + + post = self.parallel_post_lambdas[pi].to(dtype=lane0.dtype) + resid_l = self.parallel_resid_lambdas[pi].to(dtype=lane0.dtype) + + new_lane0 = resid_l[0] * lane0 + post[0, 0] * attn_out + post[1, 0] * mlp_out + new_lane1 = resid_l[1] * lane1 + post[0, 1] * attn_out + post[1, 1] * mlp_out + lane0, lane1 = new_lane0, new_lane1 + else: + x = self.blocks[layer_idx](x, x0) + + # Final merge + if lane0 is not None: + x = (lane0 + lane1) * 0.5 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + 
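+ # Added commentary (not in the original source): the val_bpb figures logged throughout the
+ # run are bits-per-byte, computed as
+ #   val_bpb = (val_loss / ln 2) * (total_tokens / total_bytes)
+ # where val_loss is mean cross-entropy in nats/token and total_bytes comes from the
+ # SentencePiece byte LUTs built in the setup section below. Purely illustrative arithmetic
+ # (not measured values): 1.75 nats/token at 2.0 bytes/token gives 1.75 / 0.693 / 2.0 ≈ 1.26 BPB.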
log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + parallel_start_layer=args.parallel_start_layer, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # Keep parallel routing params in float32 (they are not covered by restore_low_dim_params_to_fp32) + with torch.no_grad(): + if base_model.parallel_post_lambdas is not None: + base_model.parallel_post_lambdas.data = base_model.parallel_post_lambdas.data.float() + if base_model.parallel_resid_lambdas is not None: + base_model.parallel_resid_lambdas.data = base_model.parallel_resid_lambdas.data.float() + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return 
max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() 
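+ # Added commentary (not in the original source): the artifact pipeline is (1) per-row
+ # symmetric int{QUANT_BITS} quantization of every float tensor above 65,536 elements
+ # (2D weights get a GPTQ-lite search over per-row clipping percentiles), (2) fp16/fp32
+ # passthrough for small and control tensors, (3) torch.save of the quantized dict,
+ # (4) zstd level-22 (or zlib-9) compression. Quantized values are stored in an int8
+ # container regardless of QUANT_BITS, so sub-8-bit headroom is reclaimed by the entropy
+ # coder rather than by bit-packing. Rough size sketch (illustrative, before compression):
+ # a [1536, 512] weight costs 1536*512 = 786,432 int8 bytes plus 1536 fp16 row scales,
+ # i.e. about 0.79 MB.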
+ log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} 
val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase3_discriminative_ttt.py b/train_gpt_phase3_discriminative_ttt.py new file mode 100644 index 0000000000..d6de425060 --- /dev/null +++ b/train_gpt_phase3_discriminative_ttt.py @@ -0,0 +1,1896 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = 
float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + ttt_discriminative = bool(int(os.environ.get("TTT_DISCRIMINATIVE", "0"))) # default OFF + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + 
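+ # Added commentary (not in the original source): Muon is momentum SGD whose 2D updates are
+ # orthogonalized. zeropower_via_newtonschulz5 runs a quintic Newton-Schulz iteration in
+ # bfloat16 to approximate U @ V^T from the SVD of the (Nesterov-adjusted) momentum gradient,
+ # and the result is rescaled by sqrt(max(1, rows/cols)) so tall matrices keep a comparable
+ # update norm. Work is sharded round-robin across ranks (param i is handled by rank
+ # i % world_size) and the flat bf16 update buffer is summed with all_reduce before the
+ # updates are applied to every parameter on every rank.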
total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, 
TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, 
dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + 
chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + 
base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + if args.ttt_discriminative: + # Per-block LR scaling: early=0.3x, late=1.0x, linear interpolation + param_groups = [] + blocks = [b for b in base_model.blocks] if hasattr(base_model, 'blocks') else [] + num_blocks = len(blocks) + for i, block in enumerate(blocks): + if i < args.ttt_freeze_blocks: + continue # skip frozen blocks + lr_scale = 0.3 + 0.7 * (i / max(num_blocks - 1, 1)) + block_params = [p for p in block.parameters() if p.requires_grad] + if block_params: + param_groups.append({"params": block_params, "lr": args.ttt_lr * lr_scale}) + # Add non-block params (embeddings, final norm, lm_head) at full LR + block_param_ids = set() + for block in blocks: + block_param_ids.update(id(p) for p in block.parameters()) + other_params = [p for p in base_model.parameters() if p.requires_grad and id(p) not in block_param_ids] + if other_params: + param_groups.append({"params": other_params, "lr": args.ttt_lr}) + ttt_opt = torch.optim.SGD(param_groups, momentum=args.ttt_momentum) + else: + # Standard flat LR for all params + ttt_opt = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + optimizer = ttt_opt + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + effective_ttt_epochs = 10 if args.ttt_discriminative else args.ttt_epochs + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s 
= (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and effective_ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(effective_ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + 
f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, 
-clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": 
f"int{_QUANT_BITS}_clean_per_row_v1",
+ "quantized": quantized,
+ "scales": scales,
+ "dtypes": dtypes,
+ "passthrough": passthrough,
+ }
+ if qmeta:
+ obj["qmeta"] = qmeta
+ if passthrough_orig_dtypes:
+ obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+ return obj, stats
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+ out: dict[str, Tensor] = {}
+ qmeta = obj.get("qmeta", {})
+ passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+ for name, q in obj["quantized"].items():
+ dtype = getattr(torch, obj["dtypes"][name])
+ s = obj["scales"][name]
+ if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+ s = s.to(dtype=torch.float32)
+ out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+ else:
+ scale = float(s.item())
+ out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+ for name, t in obj["passthrough"].items():
+ out_t = t.detach().to("cpu").contiguous()
+ orig_dtype = passthrough_orig_dtypes.get(name)
+ if isinstance(orig_dtype, str):
+ out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+ out[name] = out_t
+ return out
+def load_data_shard(file: Path) -> Tensor:
+ # NOTE: the shard layout used below (a 256-entry int32 header followed by little-endian
+ # uint16 token ids, returned here as int32) is an assumption; adjust the header and token
+ # dtypes if the shards were written with a different layout.
+ header_bytes = 256 * np.dtype("<i4").itemsize
+ with open(file, "rb") as f:
+ f.seek(header_bytes)
+ tokens_np = np.frombuffer(f.read(), dtype="<u2")
+ return torch.from_numpy(tokens_np.astype(np.int32))
+class TokenStream:
+ # Cycles through the sorted shard files matching `pattern`, handing out contiguous runs of
+ # tokens; wraps around to the first shard once the last one is exhausted.
+ def __init__(self, pattern: str):
+ self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+ if not self.files:
+ raise FileNotFoundError(f"No files found for pattern: {pattern}")
+ self.file_idx = 0
+ self.tokens = load_data_shard(self.files[self.file_idx])
+ self.pos = 0
+ def _advance_file(self) -> None:
+ self.file_idx = (self.file_idx + 1) % len(self.files)
+ self.tokens = load_data_shard(self.files[self.file_idx])
+ self.pos = 0
+ def take(self, n: int) -> Tensor:
+ chunks: list[Tensor] = []
+ remaining = n
+ while remaining > 0:
+ avail = self.tokens.numel() - self.pos
+ if avail <= 0:
+ self._advance_file()
+ continue
+ k = min(remaining, avail)
+ chunks.append(self.tokens[self.pos : self.pos + k])
+ self.pos += k
+ remaining -= k
+ return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+class DistributedTokenLoader:
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+ self.rank = rank
+ self.world_size = world_size
+ self.device = device
+ self.stream = TokenStream(pattern)
+ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+ local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+ per_rank_span = local_tokens + 1
+ chunk = self.stream.take(per_rank_span * self.world_size)
+ start = self.rank * per_rank_span
+ local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+ x = local[:-1].reshape(-1, seq_len)
+ y = local[1:].reshape(-1, seq_len)
+ return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+class RMSNorm(nn.Module):
+ def __init__(self, eps: float | None = None):
+ super().__init__()
+ self.eps = eps
+ def forward(self, x: Tensor) -> Tensor:
+ return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+class CastedLinear(nn.Linear):
+ _int6_qat: bool = False
+ def forward(self, x: Tensor) -> Tensor:
+ w = self.weight.to(x.dtype)
+ if CastedLinear._int6_qat and self.training:
+ w = int6_ste(w)
+ bias = self.bias.to(x.dtype) if self.bias is not None else None
+ return F.linear(x, w, bias)
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+ with torch.no_grad():
+ for name, param in module.named_parameters():
+ if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+ param.data = param.data.float()
+class BigramHash(nn.Module):
+ def __init__(self, num_buckets: int, dim: int):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.emb = nn.Embedding(num_buckets, dim)
+ nn.init.zeros_(self.emb.weight)
+ def forward(self, input_ids: 
Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = 
apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + 
self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> 
None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + 
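+ # grad_accum_steps is chosen so that world_size * grad_accum_steps == 8 micro-batches per
+ # optimizer step, keeping the global token batch identical across single- and multi-GPU runs.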
grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + 
rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + 
log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with 
torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) 
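+ # MAX-reduce the wallclock flag so all ranks agree on the stop decision once any rank
+ # hits the cap, keeping the ranks in lock-step through the final validation.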
+ reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, 
world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase3_prequant_ttt.py b/train_gpt_phase3_prequant_ttt.py new file mode 100644 index 0000000000..2bf21f3f03 --- /dev/null +++ b/train_gpt_phase3_prequant_ttt.py @@ -0,0 +1,1906 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = 
int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.005)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 6)) + ttt_prequant = bool(int(os.environ.get("TTT_PREQUANT", "1"))) # default ON + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # adamw or sgd + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + 
recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 
15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + 
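+    # Reported BPB = (mean NLL in nats / ln 2) * (scored tokens / scored UTF-8 bytes);
+    # the byte counts come from the SentencePiece LUTs built in build_sentencepiece_luts().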
return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear 
wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - 
chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = 
torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + if args.ttt_optimizer == "adamw": + ttt_opt = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0) + else: + ttt_opt = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + optimizer = ttt_opt + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + 
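+                    # Map this rank's local sequence indices back to absolute token offsets
+                    # inside the current chunk (the trailing +1 supplies the shifted target).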
actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. 
When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def 
dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, 
dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] 
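+        # enable_gqa lets scaled_dot_product_attention broadcast the num_kv_heads K/V tensors
+        # across all query heads, so GQA needs no explicit repeat_interleave here; the XSA
+        # branch below expands V manually only because it subtracts V from the attention output.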
+ y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> 
Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
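Note on the logit soft-cap used in GPT.forward / forward_logits above: the raw head projection is passed through logit_softcap * tanh(x / logit_softcap), which smoothly bounds every logit to (-logit_softcap, +logit_softcap) (30.0 by default in these configs) while remaining roughly the identity near zero. A minimal standalone sketch of just that transform, with illustrative tensor values that are not taken from any training run:

import torch

def softcap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Smoothly squash logits into (-cap, cap): near zero the map is ~identity,
    # so typical logits are unchanged while outliers saturate instead of growing unbounded.
    return cap * torch.tanh(logits / cap)

raw = torch.tensor([0.5, 10.0, 100.0])
print(softcap(raw))  # approx [0.50, 9.6, 29.9] -- the largest value saturates near the cap

Because tanh is monotonic, the argmax token is unchanged; the cap only limits how extreme individual logits (and hence gradients through cross-entropy) can become.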
master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module 
in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: 
leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if 
ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with 
torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + # PRE-QUANT TTT: Run TTT on FP32 EMA weights before quantization + if args.ttt_enabled and args.ttt_prequant: + log0("prequant_ttt:starting TTT on FP32 EMA weights before quantization") + torch.cuda.synchronize() + t_prequant_ttt = time.perf_counter() + if args.ttt_lora_enabled: + prequant_ttt_loss, prequant_ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + prequant_ttt_label = "prequant_lora_ttt" + else: + prequant_ttt_loss, prequant_ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + prequant_ttt_label = "prequant_ttt" + torch.cuda.synchronize() + log0( + f"{prequant_ttt_label} val_loss:{prequant_ttt_loss:.4f} val_bpb:{prequant_ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_prequant_ttt):.0f}ms " + f"optimizer:{args.ttt_optimizer}" + ) + log0(f"{prequant_ttt_label}_exact val_loss:{prequant_ttt_loss:.8f} val_bpb:{prequant_ttt_bpb:.8f}") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = 
torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled and not args.ttt_prequant: + # Only run post-quant TTT if pre-quant was NOT used + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase4_polar_ns.py b/train_gpt_phase4_polar_ns.py new file mode 100644 index 0000000000..1a57162189 --- /dev/null +++ b/train_gpt_phase4_polar_ns.py @@ -0,0 +1,1898 @@ +from __future__ import annotations 
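Note on the val_bpb metric reported throughout these logs: eval_val converts the mean token NLL (in nats) to bits per token and then rescales by tokens-per-byte, where the byte count of each target token comes from the SentencePiece LUTs (UTF-8 bytes of the piece, plus one byte when the piece carries a leading space and the previous token is not a boundary token). A minimal sketch of the conversion with made-up counts, purely illustrative and not from a real run:

import math

def bits_per_byte(mean_nll_nats: float, token_count: float, byte_count: float) -> float:
    # nats/token -> bits/token, then rescale by how many tokens cover each byte of raw text.
    bits_per_token = mean_nll_nats / math.log(2.0)
    tokens_per_byte = token_count / byte_count
    return bits_per_token * tokens_per_byte

# Example: mean NLL of 2.0 nats/token over 1_000 tokens whose pieces decode to 3_500 bytes.
print(bits_per_byte(2.0, 1_000, 3_500))  # ~0.824 bits per byte

Normalizing by bytes rather than tokens is what makes runs with different tokenizers or sequence lengths directly comparable in these scripts.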
+import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + polar_express = bool(int(os.environ.get("POLAR_EXPRESS", "1"))) # default ON + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) # was 5, Polar Express uses 4 + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = 
bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) +_PE_COEFFS = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +] + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7, polar_express: bool = False) -> Tensor: + if polar_express: + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for i in range(min(steps, len(_PE_COEFFS))): + a, b, c = _PE_COEFFS[i] + A = X @ X.T + X = a * X + b * (A @ X) + c * (A @ (A @ X)) + if transposed: + X = X.T + return X + else: + # Original fixed-coefficient version + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if transposed: + X = X.T + return X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps, polar_express=group.get("polar_express", False)) + g *= max(1, g.size(0) / 
g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, 
local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = 
block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += 
scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = 
min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + 
pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. 
When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def 
dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, 
dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] 
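# Note on the attention call that follows: F.scaled_dot_product_attention runs
# causally, and enable_gqa is set whenever num_kv_heads != num_heads so the kernel
# shares each K/V head across its group of query heads instead of materializing
# repeated copies. When use_xsa is enabled, the code afterwards subtracts the
# (group-expanded) value tensor from the attention output, i.e. it effectively
# computes y_i = sum_j a_ij * v_j - v_i, so the layer only contributes the
# difference between the attended values and the token's own value vector.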
+ y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> 
Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
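# RANK / WORLD_SIZE / LOCAL_RANK above are the variables a distributed launcher
# exports; a hypothetical single-node invocation (illustrative only, not one of the
# repo's run scripts) would be:
#   torchrun --standalone --nproc_per_node=8 train_gpt.py
# Without a launcher the script runs single-GPU (world_size == 1) and simply uses
# more gradient-accumulation steps, keeping the global batch per step unchanged.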
master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module 
in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + group["polar_express"] = args.polar_express + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + 
) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + 
is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if 
ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * 
(time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase4_quickwins.py b/train_gpt_phase4_quickwins.py new file mode 100644 index 0000000000..2bf0c9a31d --- /dev/null +++ b/train_gpt_phase4_quickwins.py @@ -0,0 +1,1695 @@ +""" +PHASE 4: Quick Wins from Top Leaderboard (2026-04-01) +- XSA_LAYERS: 11 → 4 (like #3/#4 leaderboard entries) +- WARMDOWN_ITERS: 4000 → 3500 (like #2 leaderboard - GPTQ-lite winner) +- BIGRAMHASH_BUCKETS: 4096 → 10240 (like #6 leaderboard - Int5-MLP) +- MUON_MOMENTUM: 0.95 → 0.99 (higher momentum variant) +- RESID_LAMBDAS: OFF (ablation proved no benefit) + +Expected improvement: 0.003-0.008 BPB +""" +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + 
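# Every hyperparameter in this class is read from an environment variable with a
# default, so runs are configured entirely from the shell; a purely illustrative
# override (values chosen to match the PHASE 4 notes above, not a prescribed
# command) would look like:
#   XSA_LAYERS=4 WARMDOWN_ITERS=3500 BIGRAMHASH_BUCKETS=10240 MUON_MOMENTUM=0.99 \
#   python train_gpt_phase4_quickwins.py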
val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) # PHASE4: 4000→3500 (like #2 leaderboard) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 40)) # BESTV3: 20 → 40 (longer warmup for wider model) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) # LONGCTX: default 2048 (was 1024) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 0)) # LONGCTX: 0 = same as train_seq_len + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + max_eval_seconds = float(os.environ.get("MAX_EVAL_SECONDS", 580.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 4.0)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 576)) # BESTV3: 512 → 576 (10% wider, still fits budget) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) # PHASE4: 0.95→0.99 (higher momentum) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) # QUICKWIN: 5->4, saves 1-2ms/step + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 10240)) # PHASE4: 4096→10240 (like #6 leaderboard) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) # PHASE4: 11→4 (like #3/#4 leaderboard) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) # LONGCTX: default 64 (was 0) + cosine_warmdown = bool(int(os.environ.get("COSINE_WARMDOWN", "1"))) # QUICKWIN: cosine > linear warmdown + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 2)) # TIME BUDGET: 3->2 to leave room for SLOT with 2048 ctx + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = 
int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + split_lr = bool(int(os.environ.get("SPLIT_LR", "1"))) + late_lr_scale = float(os.environ.get("LATE_LR_SCALE", 1.2)) + slot_enabled = bool(int(os.environ.get("SLOT_ENABLED", "1"))) + slot_lr = float(os.environ.get("SLOT_LR", 0.003)) + slot_steps = int(os.environ.get("SLOT_STEPS", 5)) + # LONGCTX: YaRN positional encoding params + yarn_enabled = bool(int(os.environ.get("YARN_ENABLED", "1"))) + yarn_base_len = int(os.environ.get("YARN_BASE_LEN", 1024)) + yarn_max_len = int(os.environ.get("YARN_MAX_LEN", 2048)) + yarn_beta_fast = float(os.environ.get("YARN_BETA_FAST", 32.0)) + yarn_beta_slow = float(os.environ.get("YARN_BETA_SLOW", 1.0)) + # LONGCTX: coprime-stride data loader + coprime_stride = bool(int(os.environ.get("COPRIME_STRIDE", "1"))) + # RESIDLAMBDA: per-sublayer learned residual scaling + resid_lambdas = bool(int(os.environ.get("RESID_LAMBDAS", "0"))) # PHASE4: OFF (ablation proved no benefit) + resid_lambda_init = float(os.environ.get("RESID_LAMBDA_INIT", 1.0488)) # sqrt(1.1) + resid_lambda_lr_mult = float(os.environ.get("RESID_LAMBDA_LR_MULT", 5.0)) + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X.div_(X.norm() + eps) + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + scale_factor = max(1.0, float(g.size(0)) / float(g.size(1))) ** 0.5 + g.mul_(scale_factor) + updates_flat[curr : curr + p.numel()] = g.view(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / 31.0 + 
w_q = (w / scale).round().clamp(-31, 31) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].view(-1, args.train_seq_len) + y = local[1:].view(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count 
+= batch_token_count + prev_ids = x.view(-1) + tgt_ids = y.view(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + # LONGCTX: use eval_seq_len if set, otherwise fall back to train_seq_len + seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1] + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.view(-1, logits.size(-1)).float(), + y_batch.view(-1), + reduction="none", + ).view(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, 
+ rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, + max_seconds: float = 0.0, +) -> tuple[float, float]: + # LONGCTX: use eval_seq_len if set, otherwise fall back to train_seq_len + seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + for ci in range(num_chunks): + if max_seconds > 0 and (time.perf_counter() - t0) > max_seconds: + log0(f"ttt_sliding:time_limit reached after {ci}/{num_chunks} chunks, {time.perf_counter()-t0:.0f}s") + break + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1] + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.view(-1, logits.size(-1)).float(), + y_batch.view(-1), reduction="none", + ).view(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, 
s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].view(-1, seq_len) + y = local[1:].view(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 100 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +def eval_val_sliding_slot( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, + max_seconds: float = 0.0, +) -> tuple[float, float]: + # LONGCTX: use eval_seq_len if set, otherwise fall back to train_seq_len + seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + 
base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + t0 = time.perf_counter() + log0(f"slot:start windows={len(my_windows)} stride={stride} " + f"slot_lr={args.slot_lr} slot_steps={args.slot_steps}") + for bi in range(0, len(my_windows), batch_seqs): + if max_seconds > 0 and (time.perf_counter() - t0) > max_seconds: + log0(f"slot:time_limit reached at batch {bi}/{len(my_windows)}, " + f"{time.perf_counter()-t0:.0f}s") + break + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1] + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + H = base_model.forward_hidden(x_batch) + H = H.detach() + delta = torch.zeros(1, 1, H.size(-1), device=device, dtype=H.dtype, + requires_grad=True) + slot_opt = torch.optim.AdamW([delta], lr=args.slot_lr, + weight_decay=1e-8, eps=1e-5) + targets_flat = y_batch.view(-1) + for _ in range(args.slot_steps): + slot_opt.zero_grad() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.compute_logits(H + delta) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)).float(), + targets_flat) + loss.backward() + slot_opt.step() + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.compute_logits(H + delta.detach()) + nll = F.cross_entropy( + logits.view(-1, logits.size(-1)).float(), + targets_flat, reduction="none", + ).view(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi % (batch_seqs * 100) == 0): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" slot_batch [{bi}/{len(my_windows)}] bpb={rbpb:.6f} time={elapsed:.1f}s") + for p in base_model.parameters(): + p.requires_grad_(True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + log0(f"slot:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +# Resolve quantization mode at import time, not runtime +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) + +if _INT3_MODE: + _QUANT_MAX_VAL = 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + 
"CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gate,skip_gates,smear,attn_lambda,mlp_lambda", # RESIDLAMBDA: added attn_lambda,mlp_lambda + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + device = t.device + t32 = t.float() + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.clamp(t32, -clip_abs[:, None], clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE, device=device) + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale_val = clip_abs / qmax if clip_abs > 0 else 1.0 + scale = torch.tensor(scale_val, dtype=torch.float32, device=device) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], rank: int = 0, world_size: int = 1): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 
1 + passthrough[name] = t.cpu() + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t.cpu(), passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q.cpu() + scales[name] = s.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype("<i4").itemsize + # Assumes the standard fineweb .bin shard layout: a 256-int32 header (magic, version, token count) followed by uint16 token ids. + raw = file.read_bytes() + header = np.frombuffer(raw, dtype="<i4", count=256) + num_tokens = int(header[2]) + tokens = np.frombuffer(raw, dtype="<u2", offset=header_bytes, count=num_tokens) + return torch.from_numpy(tokens.astype(np.int32)) +def _find_coprime(n: int) -> int: # LONGCTX: coprime step for shard traversal + """Find a step coprime to n for maximum coverage diversity.""" + if n <= 1: + return 1 + for candidate in [7, 11, 13, 17, 19, 23, 29, 31, 37, 41]: + if math.gcd(candidate, n) == 1: + return candidate + for candidate in range(2, n): + if math.gcd(candidate, n) == 1: + return candidate + return 1 +class TokenStream: + def __init__(self, pattern: str, coprime_stride: bool = False): # LONGCTX: coprime_stride param + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + # LONGCTX: coprime-stride shard access to avoid repeated patterns + if coprime_stride and len(self.files) > 1: + self._step = _find_coprime(len(self.files)) + else: + self._step = 1 + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + self._step) % len(self.files) # LONGCTX: coprime step + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device, + coprime_stride: bool = False): # LONGCTX: coprime_stride param + self.rank =
rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern, coprime_stride=coprime_stride) # LONGCTX + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64, device=self.device, non_blocking=True) + x = local[:-1].view(-1, seq_len) + y = local[1:].view(-1, seq_len) + return x, y +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, + yarn_enabled: bool = False, yarn_base_len: int = 1024, + yarn_max_len: int = 2048, yarn_beta_fast: float = 32.0, + yarn_beta_slow: float = 1.0): # LONGCTX: YaRN params + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + # LONGCTX: Apply YaRN frequency scaling for long context support + if yarn_enabled and yarn_max_len > yarn_base_len: + scale = yarn_max_len / yarn_base_len + wavelengths = 2.0 * math.pi / inv_freq + low_thresh = yarn_base_len / yarn_beta_fast + high_thresh = yarn_base_len / yarn_beta_slow + ramp = (wavelengths - low_thresh) / max(high_thresh - low_thresh, 1e-8) + ramp = ramp.clamp(0.0, 1.0) + inv_freq = inv_freq / (ramp + (1.0 - ramp) * scale) + self._yarn_attn_factor = math.sqrt(1.0 + math.log(scale) * 0.1) + else: + self._yarn_attn_factor = 1.0 + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + @property + def yarn_attn_factor(self) -> float: + return self._yarn_attn_factor + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is 
None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x2 * cos - x1 * sin), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + yarn_enabled: bool = False, + yarn_base_len: int = 1024, + yarn_max_len: int = 2048, + yarn_beta_fast: float = 32.0, + yarn_beta_slow: float = 1.0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base, + yarn_enabled=yarn_enabled, yarn_base_len=yarn_base_len, + yarn_max_len=yarn_max_len, yarn_beta_fast=yarn_beta_fast, + yarn_beta_slow=yarn_beta_slow) # LONGCTX: YaRN params + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q_gain = self.q_gain.to(dtype=q.dtype) + q = q * q_gain[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def 
__init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x * x) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + resid_lambda_init: float = 0.0, # RESIDLAMBDA: 0 = disabled, >0 = init value + yarn_enabled: bool = False, + yarn_base_len: int = 1024, + yarn_max_len: int = 2048, + yarn_beta_fast: float = 32.0, + yarn_beta_slow: float = 1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + yarn_enabled=yarn_enabled, yarn_base_len=yarn_base_len, + yarn_max_len=yarn_max_len, yarn_beta_fast=yarn_beta_fast, + yarn_beta_slow=yarn_beta_slow, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + # RESIDLAMBDA: per-sublayer learned residual scaling vectors + if resid_lambda_init > 0: + self.attn_lambda = nn.Parameter(torch.full((dim,), resid_lambda_init, dtype=torch.float32)) + self.mlp_lambda = nn.Parameter(torch.full((dim,), resid_lambda_init, dtype=torch.float32)) + else: + self.attn_lambda = None + self.mlp_lambda = None + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + attn_scale = self.attn_scale.to(dtype=x.dtype) + mlp_scale = self.mlp_scale.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + # RESIDLAMBDA: scale residual contribution + if self.attn_lambda is not None: + attn_out = self.attn_lambda.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + attn_scale[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + mlp_out = self.mlp(mlp_normed) + # RESIDLAMBDA: scale residual contribution + if self.mlp_lambda is not None: + mlp_out = self.mlp_lambda.to(dtype=x.dtype)[None, None, :] * mlp_out + x = x + mlp_scale[None, None, :] * mlp_out + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + resid_lambda_init: float = 0.0, # RESIDLAMBDA: 0 = disabled + yarn_enabled: bool = False, + yarn_base_len: int = 1024, + yarn_max_len: int = 2048, + yarn_beta_fast: float = 32.0, 
+ yarn_beta_slow: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + resid_lambda_init=resid_lambda_init, # RESIDLAMBDA: pass through + yarn_enabled=yarn_enabled, yarn_base_len=yarn_base_len, + yarn_max_len=yarn_max_len, yarn_beta_fast=yarn_beta_fast, + yarn_beta_slow=yarn_beta_slow, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def _run_blocks(self, x: Tensor, x0: Tensor) -> Tensor: + """Shared encoder-decoder pass with sigmoid-gated skip connections.""" + skips: list[Tensor] = [torch.empty(0)] * self.num_encoder_layers + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips[i] = x + skip_idx = self.num_encoder_layers - 1 + for i in range(self.num_decoder_layers): + if skip_idx >= 0: + g = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips[skip_idx], x, g) + skip_idx -= 1 + x = self.blocks[self.num_encoder_layers + i](x, x0) + return x + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + x = self._run_blocks(x, x0) + x = self.final_norm(x).view(-1, x.size(-1)) + targets = target_ids.view(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_hidden(self, input_ids: Tensor) -> Tensor: + """Run embedding + all blocks, return pre-norm hidden states.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not 
None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + return self._run_blocks(x, x0) + def compute_logits(self, h: Tensor) -> Tensor: + """Compute logits from hidden states (post-blocks, pre-norm).""" + x = self.final_norm(h) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + def forward_logits(self, input_ids: Tensor) -> Tensor: + return self.compute_logits(self.forward_hidden(input_ids)) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + # LONGCTX: resolve eval_seq_len (0 means same as train_seq_len) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) 
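# A minimal, self-contained sketch of the byte accounting behind val_bpb, using toy LUT values
# (not the real tokenizer): each scored token costs its piece's UTF-8 byte length, plus one extra
# byte for the SentencePiece leading-space marker whenever the previous token is not a boundary
# (control/unknown/unused) token. The real LUTs are built by build_sentencepiece_luts just below.
import torch
base_bytes = torch.tensor([0, 3, 5], dtype=torch.int16)   # UTF-8 bytes per piece (id 0 = control token)
leading_space = torch.tensor([False, True, False])         # piece begins with the "▁" space marker
boundary = torch.tensor([True, False, False])              # control/unknown/unused ids
prev_ids = torch.tensor([0, 2])                            # token preceding each scored target
tgt_ids = torch.tensor([1, 1])                             # scored targets
token_bytes = base_bytes[tgt_ids] + (leading_space[tgt_ids] & ~boundary[prev_ids]).to(torch.int16)
print(token_bytes)  # -> [3, 4]: the leading space is free after a boundary token, 1 byte otherwise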
+ # LONGCTX: align val tokens to max of train/eval seq_len for both chunked and sliding eval + val_tokens = load_validation_tokens(args.val_files, max(args.train_seq_len, effective_eval_seq_len)) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + resid_lambda_init=args.resid_lambda_init if args.resid_lambdas else 0.0, # RESIDLAMBDA + yarn_enabled=args.yarn_enabled, # LONGCTX + yarn_base_len=args.yarn_base_len, + yarn_max_len=args.yarn_max_len, + yarn_beta_fast=args.yarn_beta_fast, + yarn_beta_slow=args.yarn_beta_slow, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + # Split matrix params into early/late for split Muon LR + early_cutoff = args.num_layers // 2 + early_matrix_params: list[nn.Parameter] = [] + late_matrix_params: list[nn.Parameter] = [] + # RESIDLAMBDA: separate lambda params from other scalars for higher LR + resid_lambda_params: list[nn.Parameter] = [] + scalar_params: list[nn.Parameter] = [] + for name, p in block_named_params: + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + layer_idx = int(name.split(".")[0]) + if layer_idx < early_cutoff: + early_matrix_params.append(p) + else: + late_matrix_params.append(p) + else: + if "attn_lambda" in name or "mlp_lambda" in name: # RESIDLAMBDA + resid_lambda_params.append(p) + else: + scalar_params.append(p) + matrix_params = early_matrix_params + late_matrix_params + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.skip_gates) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + # Split Muon LR: early layers use MATRIX_LR, late layers use MATRIX_LR * LATE_LR_SCALE + early_lr = args.matrix_lr + late_lr = args.matrix_lr * args.late_lr_scale if 
args.split_lr else args.matrix_lr + if args.split_lr and early_matrix_params and late_matrix_params: + optimizer_muon = Muon( + [{"params": early_matrix_params}, {"params": late_matrix_params}], + lr=early_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + optimizer_muon.param_groups[0]["base_lr"] = early_lr + optimizer_muon.param_groups[1]["lr"] = late_lr + optimizer_muon.param_groups[1]["base_lr"] = late_lr + else: + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + # RESIDLAMBDA: separate optimizer for lambda params with higher LR + if resid_lambda_params: + lambda_lr = args.scalar_lr * args.resid_lambda_lr_mult + optimizer_lambda = torch.optim.Adam( + [{"params": resid_lambda_params, "lr": lambda_lr, "base_lr": lambda_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_lambda) + log0(f"resid_lambdas:enabled init={args.resid_lambda_init} lr={lambda_lr} " + f"n_params={sum(p.numel() for p in resid_lambda_params)}") + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " + f"split_lr:{args.split_lr} late_lr:{late_lr:.4f}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:lzma" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + # LONGCTX: log long context configuration + log0( + f"longctx: train_seq_len:{args.train_seq_len} eval_seq_len:{effective_eval_seq_len} " + f"eval_stride:{args.eval_stride} yarn_enabled:{args.yarn_enabled} " + f"yarn_base_len:{args.yarn_base_len} yarn_max_len:{args.yarn_max_len} " + f"yarn_beta_fast:{args.yarn_beta_fast} yarn_beta_slow:{args.yarn_beta_slow} " + f"coprime_stride:{args.coprime_stride}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + coprime_stride=args.coprime_stride) # LONGCTX 
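# A small worked example of how the loader above carves a global batch per rank and micro-step.
# The numbers are assumptions matching the defaults used elsewhere in this PR
# (TRAIN_BATCH_TOKENS=524288, TRAIN_SEQ_LEN=1024); only the arithmetic mirrors next_batch().
global_tokens, seq_len, world_size, grad_accum_steps, rank = 524_288, 1024, 8, 1, 3
local_tokens = global_tokens // (world_size * grad_accum_steps)  # 65536 tokens per rank per micro-step
per_rank_span = local_tokens + 1                                 # one extra id so targets can shift by one
start = rank * per_rank_span                                     # where this rank's slice of the shared chunk begins
print(start, local_tokens // seq_len)                            # -> 196611 64: read offset, and 1024-token rows per micro-batch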
+ def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + linear = remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + # QUICKWIN: cosine warmdown (smoother than linear) + if args.cosine_warmdown and linear < 1.0: + return 0.5 * (1.0 + math.cos(math.pi * (1.0 - linear))) + return linear + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + coprime_stride=args.coprime_stride) # LONGCTX + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / 
max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss_accum = 0.0 + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss_accum += loss.detach().item() + (loss * grad_scale).backward() + train_loss = train_loss_accum / grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + decay = args.ema_decay + one_minus_decay = 1.0 - decay + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=one_minus_decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = Path("checkpoints") + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss:.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + 
code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + eval_start_time = time.perf_counter() + log0(f"eval_phase:starting (separate {args.max_eval_seconds}s budget)") + quant_label = f"int{_QUANT_BITS}" + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + log0(f"quantization:done time={1000*(time.perf_counter()-t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + compress_label = "lzma" + t_compress = time.perf_counter() + quant_blob = lzma.compress(quant_raw, preset=6) + log0(f"compression:done time={1000*(time.perf_counter()-t_compress):.0f}ms") + quant_raw_bytes = len(quant_raw) + quant_obj["__compress_format__"] = compress_label + quant_buf_with_meta = io.BytesIO() + torch.save(quant_obj, quant_buf_with_meta) + quant_raw_with_meta = quant_buf_with_meta.getvalue() + quant_blob = lzma.compress(quant_raw_with_meta, preset=6) + artifact_name = f"final_model.{quant_label}.ptz" + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + if distributed: dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + t_decompress = time.perf_counter() + decompressed = lzma.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + log0(f"decompression:done time={1000*(time.perf_counter()-t_decompress):.0f}ms") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < effective_eval_seq_len: # LONGCTX: use eval_seq_len + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + elapsed_eval = time.perf_counter() - eval_start_time + if args.max_eval_seconds > 0 and elapsed_eval > args.max_eval_seconds - 60: + log0(f"eval:skipping_ttt 
elapsed={elapsed_eval:.0f}s budget={args.max_eval_seconds}s") + else: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_time_budget = max(args.max_eval_seconds - elapsed_eval - 30, 0) if args.max_eval_seconds > 0 else 0 + log0(f"ttt:time_budget={ttt_time_budget:.0f}s (eval_elapsed={elapsed_eval:.0f}s)") + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + max_seconds=ttt_time_budget, + ) + torch.cuda.synchronize() + log0( + f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if args.slot_enabled: + elapsed_eval = time.perf_counter() - eval_start_time + if args.max_eval_seconds > 0 and elapsed_eval > args.max_eval_seconds - 60: + log0(f"eval:skipping_slot elapsed={elapsed_eval:.0f}s budget={args.max_eval_seconds}s") + else: + torch.cuda.synchronize() + t_slot = time.perf_counter() + slot_time_budget = max(args.max_eval_seconds - elapsed_eval - 10, 0) if args.max_eval_seconds > 0 else 0 + log0(f"slot:time_budget={slot_time_budget:.0f}s (eval_elapsed={elapsed_eval:.0f}s)") + slot_loss, slot_bpb = eval_val_sliding_slot( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + max_seconds=slot_time_budget, + ) + torch.cuda.synchronize() + log0( + f"slot val_loss:{slot_loss:.4f} val_bpb:{slot_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_slot):.0f}ms" + ) + log0(f"slot_exact val_loss:{slot_loss:.8f} val_bpb:{slot_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase4_swa.py b/train_gpt_phase4_swa.py new file mode 100644 index 0000000000..e876b8fd31 --- /dev/null +++ b/train_gpt_phase4_swa.py @@ -0,0 +1,1900 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", 
"./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = 
bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0"))) # default OFF + swa_every = int(os.environ.get("SWA_EVERY", 50)) # accumulate every N steps + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.2)) # only during final 20% of warmdown (lr_scale < 0.2) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 
15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + 
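+# The second value returned here is bits-per-byte: the mean NLL is in nats/token,
+# so dividing by ln(2) gives bits/token, and multiplying by tokens-per-byte
+# (scored tokens over their UTF-8 byte count) gives bits/byte. For example,
+# 0.95 nats/token at 0.92 tokens/byte is roughly 0.95 / 0.693 * 0.92 ≈ 1.26 bpb.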
return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear 
wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - 
chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = 
torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > 
val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if 
t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = 
idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = 
self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + 
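+# Grouped-query attention: with num_kv_heads < num_heads, the c_k/c_v projections
+# below emit only num_kv_heads * head_dim features, and scaled_dot_product_attention
+# shares each KV head across num_heads // num_kv_heads query heads (via enable_gqa).
+# E.g. dim=512 with 8 heads and 4 KV heads gives head_dim=64 and kv_dim=256,
+# halving the K/V projection parameters versus full MHA.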
self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, 
+ qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + 
recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: 
list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = 
spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = 
torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + 
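+# This forward/backward pass is warmup-only: it lets compilation, autotuning and
+# allocator growth happen before the wallclock timer starts. The parameter and
+# optimizer snapshots taken above are restored right after the warmup loop (and
+# the data loader is rebuilt), so the timed run still starts from the original state.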
(warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + if args.swa_enabled: + swa_state = {} # will be populated on first SWA checkpoint + swa_count = 0 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = 
args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + # SWA accumulation (only during late warmdown) + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_count == 0: + # First checkpoint: initialize SWA state + swa_state = {name: p.data.clone().float() for name, p in base_model.named_parameters()} + else: + # Running average: swa = (swa * count + new) / (count + 1) + for name, p in base_model.named_parameters(): + swa_state[name].mul_(swa_count).add_(p.data.float()).div_(swa_count + 1) + swa_count += 1 + if master_process: + log0(f"SWA checkpoint {swa_count} at step {step} (lr_scale={scale:.4f})") + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if args.swa_enabled and swa_count > 0: + if master_process: + log0(f"Merging EMA + SWA ({swa_count} checkpoints)") + with torch.no_grad(): + for name, p in base_model.named_parameters(): + # Average EMA and SWA with equal weight + ema_val = ema_state[name] + swa_val = swa_state[name].to(dtype=ema_val.dtype) + p.data.copy_((ema_val + swa_val) * 0.5) + else: + # Standard: just apply EMA + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + 
base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) 
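+ # Both TTT paths score each validation chunk before (optionally) training on it, and the
+ # final chunk is never trained on, so every scored token is predicted by weights that have
+ # not yet been updated on that token; presumably this is why the labels say "legal".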
+ ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase4_ve128.py b/train_gpt_phase4_ve128.py new file mode 100644 index 0000000000..e7602ddc6a --- /dev/null +++ b/train_gpt_phase4_ve128.py @@ -0,0 +1,1910 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = 
float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0"))) # default OFF + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") # comma-separated layer indices +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = 
group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must 
provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end 
+ 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + 
f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + 
wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( 
+ args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + 
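# Per-window bookkeeping: byte counts come from the SentencePiece LUTs (UTF-8 bytes of
+ # each target piece, plus one byte for the leading "▁" space unless the previous token
+ # is a boundary/control token; e.g. "▁the" after a normal token counts as 3 + 1 = 4 bytes),
+ # and the reported metric is bpb = (mean NLL / ln 2) * (token_count / byte_count). +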
wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. 
When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def 
dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, 
dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ve: ValueEmbedding = None, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + self.ve = ve + def forward(self, x: Tensor, input_ids: Tensor = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + if self.ve is not None and input_ids is not None: + ve_out = self.ve(input_ids) # [bsz, seq, kv_dim] + # Reshape to match V: [bsz, num_kv_heads, seq, head_dim] + ve_out = ve_out.view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v + ve_out + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], 
k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.emb = nn.Embedding(vocab_size, ve_dim) + self.proj = nn.Linear(ve_dim, kv_dim, bias=False) + nn.init.zeros_(self.emb.weight) + nn.init.zeros_(self.proj.weight) + + def forward(self, token_ids: Tensor) -> Tensor: + return self.proj(self.emb(token_ids)) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ve: ValueEmbedding = None, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ve=ve, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor, input_ids: 
Tensor = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed, input_ids=input_ids) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ve_layer_indices: list = None, + ve_dim: int = 128, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + self.ve_layer_indices = ve_layer_indices if ve_layer_indices is not None else [] + if self.ve_layer_indices: + kv_dim = num_kv_heads * (model_dim // num_heads) + self.value_embedding = ValueEmbedding(vocab_size, ve_dim, kv_dim) + else: + self.value_embedding = None + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + 
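# U-Net wiring: with unet_skips the first num_layers // 2 blocks act as the "encoder"
+ # and push their activations onto a stack; each "decoder" block then adds the matching
+ # popped activation back in, weighted by a learned per-channel skip_weights row
+ # (initialized to ones). Without unet_skips all blocks run as a plain stack. +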
xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ve=self.value_embedding if i in self.ve_layer_indices else None, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0, input_ids=input_ids if i in self.ve_layer_indices else None) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + block_idx = self.num_encoder_layers + i + x = self.blocks[block_idx](x, x0, input_ids=input_ids if block_idx in self.ve_layer_indices else None) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0, input_ids=input_ids if i in self.ve_layer_indices else None) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + 
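+            # Skips are consumed LIFO: decoder block i receives the output of encoder block
+            # (num_encoder_layers - 1 - i), scaled per channel by the learned skip_weights[i].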
block_idx = self.num_encoder_layers + i + x = self.blocks[block_idx](x, x0, input_ids=input_ids if block_idx in self.ve_layer_indices else None) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer 
vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + ve_layer_indices = [int(x) for x in args.ve_layers.split(",") if x.strip()] if args.ve_enabled else [] + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ve_layer_indices=ve_layer_indices, + ve_dim=args.ve_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + if base_model.value_embedding is not None: + embed_params.append(base_model.value_embedding.emb.weight) + embed_params.append(base_model.value_embedding.proj.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + 
backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, 
args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not 
None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} 
bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - 
t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase5_full_gptq.py b/train_gpt_phase5_full_gptq.py new file mode 100644 index 0000000000..2c93a8dd85 --- /dev/null +++ b/train_gpt_phase5_full_gptq.py @@ -0,0 +1,2054 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + bigramhash_buckets = 
int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + gptq_full_hessian = bool(int(os.environ.get("GPTQ_FULL_HESSIAN", "1"))) # default ON + gptq_damp = float(os.environ.get("GPTQ_DAMP", 0.005)) # Hessian damping factor + gptq_blocksize = int(os.environ.get("GPTQ_BLOCKSIZE", 128)) + gptq_actorder = bool(int(os.environ.get("GPTQ_ACTORDER", "1"))) # column reordering + gptq_ar_selfgen = bool(int(os.environ.get("GPTQ_AR_SELFGEN", "1"))) # AR self-gen calibration + gptq_calib_samples = int(os.environ.get("GPTQ_CALIB_SAMPLES", 64)) + gptq_calib_seqlen = int(os.environ.get("GPTQ_CALIB_SEQLEN", 2048)) + gptq_calib_temp = float(os.environ.get("GPTQ_CALIB_TEMP", 0.8)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = 
torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = 
torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum 
+= scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + 
[base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), 
reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + 
seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + 
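+                # Each chunk is scored above with the current weights before the model is allowed
+                # to adapt on it, so no token is evaluated by weights already trained on that token;
+                # the final chunk is never trained on because nothing remains to score after it.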
base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +class GPTQ: + """Full Hessian GPTQ quantizer for a single linear layer.""" + def __init__(self, weight: Tensor): + self.rows, self.cols = weight.shape + self.H = torch.zeros((self.cols, self.cols), device=weight.device, dtype=torch.float32) + self.nsamples = 0 + + def add_batch(self, inp: Tensor): + """Accumulate Hessian H = X^T X from input activations.""" + # inp: [batch*seq, d_in] + inp = inp.reshape(-1, inp.shape[-1]).float() + n = inp.shape[0] + self.H *= self.nsamples / (self.nsamples + n) + self.nsamples += n + self.H += (2.0 / (self.nsamples)) * (inp.T @ inp) + + def quantize(self, weight: Tensor, qmax: int = 31, blocksize: int = 128, + percdamp: float = 0.005, actorder: bool = True): + """Run full Hessian GPTQ quantization with Cholesky error compensation.""" + W = weight.clone().float() + H = self.H.clone() + + # Handle dead columns + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + # Damping for numerical stability + damp = percdamp * torch.mean(torch.diag(H)) + diag_idx = torch.arange(self.cols, device=H.device) + H[diag_idx, diag_idx] += damp + + # Column reordering by Hessian diagonal (actorder) + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + invperm = torch.argsort(perm) + + # Cholesky decomposition of H^{-1} + try: + H_inv = torch.linalg.cholesky(H) + H_inv = torch.cholesky_inverse(H_inv) + H_inv = torch.linalg.cholesky(H_inv, upper=True) + except torch.linalg.LinAlgError: + # Fallback: add more damping and retry + H[diag_idx, diag_idx] += damp * 10 + H_inv = torch.linalg.cholesky(H) + H_inv = torch.cholesky_inverse(H_inv) + H_inv = torch.linalg.cholesky(H_inv, upper=True) + + Q = torch.zeros_like(W) + + for i1 in range(0, self.cols, blocksize): + i2 = min(i1 + blocksize, self.cols) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Hinv1 = H_inv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + # Quantize: symmetric int6 with per-row scale + scale = w.abs().max().clamp(min=1e-8) / qmax + q = (w / scale).round().clamp(-qmax, qmax) * scale + + Q1[:, i] = q + err = (w - q) / d + 
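+                # GPTQ error compensation: the rounding error of column i, scaled by the diagonal
+                # entry d of the Cholesky factor of H^-1, is propagated onto the not-yet-quantized
+                # columns, i.e. w_j -= (w_i - q_i) * Hinv1[i, j] / Hinv1[i, i] for j > i.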
W1[:, i + 1:] -= err.unsqueeze(1) * Hinv1[i, i + 1:].unsqueeze(0) + Err1[:, i] = err + + Q[:, i1:i2] = Q1 + # Propagate errors to remaining columns + W[:, i2:] -= Err1 @ H_inv[i1:i2, i2:] + + if actorder: + Q = Q[:, invperm] + + return Q + + +@torch.no_grad() +def generate_ar_calibration(model: nn.Module, num_samples: int = 64, seq_len: int = 2048, + temp: float = 0.8, vocab_size: int = 1024) -> Tensor: + """Generate calibration data autoregressively from the model itself.""" + device = next(model.parameters()).device + all_tokens = [] + for _ in range(num_samples): + tokens = torch.randint(0, vocab_size, (1, 1), device=device) + for _ in range(seq_len - 1): + logits = model.forward_logits(tokens) + next_logits = logits[:, -1, :] / temp + probs = torch.softmax(next_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + tokens = torch.cat([tokens, next_token], dim=1) + all_tokens.append(tokens.squeeze(0)) + return torch.stack(all_tokens) # [num_samples, seq_len] + + +def collect_hessians(model: nn.Module, calibration_data: Tensor, gptq_modules: dict) -> None: + """Run calibration data through the model, collecting Hessian for each linear layer.""" + hooks = [] + + def make_hook(name: str, gptq_obj: GPTQ): + def hook_fn(module: nn.Module, input: tuple, output: Tensor) -> None: + gptq_obj.add_batch(input[0]) + return hook_fn + + # Register hooks on all CastedLinear modules + for name, module in model.named_modules(): + if isinstance(module, CastedLinear) and name in gptq_modules: + h = module.register_forward_hook(make_hook(name, gptq_modules[name])) + hooks.append(h) + + # Run calibration data through model + for i in range(calibration_data.shape[0]): + tokens = calibration_data[i:i+1] + model(tokens[:, :-1], tokens[:, 1:]) + + # Remove hooks + for h in hooks: + h.remove() + + +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, 
dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: 
dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + 
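+# Descriptive note: the bucket table is zero-initialized on the next line, so the
+# bigram feature contributes nothing at step 0 and only adds signal as it trains.
+# forward() folds each (prev, cur) token pair into a bucket via
+# (prev * 31 + cur) % num_buckets; e.g. with the default 4096 buckets the pair
+# (17, 42) lands in bucket (17*31 + 42) % 4096 = 569. Collisions are expected and
+# harmless here: colliding pairs simply share one learned vector.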
nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., 
:self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * 
self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + 
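+# Descriptive note on the two depth layouts built by this constructor. With
+# RECURRENCE_REPEATS=0 (the default) the model is a plain stack of num_layers
+# blocks; when unet_skips is on, the first half acts as an encoder whose outputs
+# are re-injected into the second half through learned per-channel skip_weights.
+# With RECURRENCE_REPEATS>0 the stack becomes head_blocks -> one shared_block
+# applied repeatedly (each pass gated by a learned sigmoid so it can interpolate
+# with its input) -> tail_blocks. As a worked example (a hypothetical setting, not
+# one of the configured runs): RECURRENCE_REPEATS=9 with the default head=2/tail=2
+# gives 2 + 9 + 2 = 13 effective layers while storing only 2 + 1 + 2 = 5 blocks'
+# worth of unique weights.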
def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays 
integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + 
unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | 
None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if ema_backup is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_backup[name]) + del ema_backup + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if ema_state is not None: + with torch.no_grad(): + decay = args.ema_decay + for name, p in base_model.named_parameters(): + ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay) + step += 1 + if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process: + ckpt_dir = run_dir / "checkpoints" + ckpt_dir.mkdir(exist_ok=True) + ckpt_path = ckpt_dir / f"ckpt_step{step}.pt" + ckpt_data = {"step": step, "model": base_model.state_dict()} + if ema_state is not None: + ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()} + torch.save(ckpt_data, ckpt_path) + ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime) + for old in ckpts[:-2]: + old.unlink(missing_ok=True) + log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)") + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), 
device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if ema_state is not None: + with torch.no_grad(): + for name, p in base_model.named_parameters(): + p.data.copy_(ema_state[name]) + log0("ema:applied EMA weights for export") + fp32_path = str(run_dir / "final_model.pt") + if master_process: + torch.save(base_model.state_dict(), fp32_path) + model_bytes = os.path.getsize(fp32_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes → {fp32_path}") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_label = f"int{_QUANT_BITS}" + compress_label = "zstd" if _HAS_ZSTD else "zlib" + if args.gptq_full_hessian: + if master_process: + log0("Running Full Hessian GPTQ...") + + # Step 1: Generate AR calibration data + calib_data: Tensor + if args.gptq_ar_selfgen: + if master_process: + log0(f"Generating {args.gptq_calib_samples} AR calibration sequences...") + base_model.eval() + calib_data = generate_ar_calibration( + base_model, args.gptq_calib_samples, args.gptq_calib_seqlen, + args.gptq_calib_temp, args.vocab_size + ) + base_model.train() + else: + raise ValueError("gptq_ar_selfgen=False requires an external calibration source (not implemented)") + + # Step 2: Create GPTQ objects for each CastedLinear + gptq_modules: dict[str, GPTQ] = {} + for name, module in base_model.named_modules(): + if isinstance(module, CastedLinear): + gptq_modules[name] = GPTQ(module.weight.data) + + # Step 3: Collect Hessians + base_model.eval() + with torch.no_grad(): + collect_hessians(base_model, calib_data, gptq_modules) + base_model.train() + + # Step 4: Quantize each layer + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, CastedLinear) and name in gptq_modules: + gptq_obj = gptq_modules[name] + Q = gptq_obj.quantize( + module.weight.data, + qmax=_QUANT_MAX_VAL, + blocksize=args.gptq_blocksize, + percdamp=args.gptq_damp, + actorder=args.gptq_actorder + ) + module.weight.data.copy_(Q) + + if master_process: + log0("Full Hessian GPTQ complete") + + log0(f"quantization:start {quant_label} distributing across {world_size} GPUs") + torch.cuda.synchronize() + t_quant = time.perf_counter() + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size, + ) + torch.cuda.synchronize() + log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _HAS_ZSTD: + compressor = zstd.ZstdCompressor(level=22, threads=-1) # -1 = use all CPU cores + quant_blob = compressor.compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + artifact_name = str(run_dir / f"final_model.{quant_label}.ptz") + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}+{compress_label}: 
{quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes") + log0(f"artifact_path:{artifact_name}") + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + if _HAS_ZSTD: + decompressor = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + + # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag + if args.ttt_lora_enabled: + ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_lora_ttt" + else: + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + log0=log0, + ) + ttt_label = "legal_ttt" + + torch.cuda.synchronize() + log0( + f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if int(os.environ.get("NGRAM_EVAL", "0")): + ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9")) + log0(f"ngram_eval:starting max_order={ngram_max_order}") + torch.cuda.synchronize() + t_ngram = time.perf_counter() + try: + from ngram_eval import eval_val_ngram + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, base_model, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + max_order=ngram_max_order, log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"eval_time:{1000.0 * 
(time.perf_counter() - t_ngram):.0f}ms" + ) + log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + except Exception as e: + log0(f"ngram_eval:FAILED {e}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_phase6_causal_slot.py b/train_gpt_phase6_causal_slot.py new file mode 100644 index 0000000000..ba2ba0aba8 --- /dev/null +++ b/train_gpt_phase6_causal_slot.py @@ -0,0 +1,2063 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 
0.0)) + bigramhash_buckets = int(os.environ.get("BIGRAMHASH_BUCKETS", 4096)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + smeargate = bool(int(os.environ.get("SMEARGATE", "1"))) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "1"))) + int6_qat = bool(int(os.environ.get("INT6_QAT", "1"))) + rope_partial_dims = int(os.environ.get("ROPE_PARTIAL_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_lora_enabled = bool(int(os.environ.get("TTT_LORA_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.001)) + recurrence_repeats = int(os.environ.get("RECURRENCE_REPEATS", 0)) + recurrence_unique_head = int(os.environ.get("RECURRENCE_UNIQUE_HEAD", 2)) + recurrence_unique_tail = int(os.environ.get("RECURRENCE_UNIQUE_TAIL", 2)) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", 0)) + causal_slot_enabled = bool(int(os.environ.get("CAUSAL_SLOT_ENABLED", "0"))) # default OFF + causal_slot_steps = int(os.environ.get("CAUSAL_SLOT_STEPS", 8)) + causal_slot_lr = float(os.environ.get("CAUSAL_SLOT_LR", 0.005)) + causal_slot_dim = int(os.environ.get("CAUSAL_SLOT_DIM", 0)) # 0 = model_dim (hidden delta), >0 = that dim +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + 
p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def int6_ste(w: Tensor) -> Tensor: + qmax = _QUANT_MAX_VAL # Uses QUANT_BITS override if set (e.g. 15 for int5, 7 for int4) + if w.ndim == 2: + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + abs_max = w.abs().amax().clamp_min(1e-8) + scale = abs_max / float(qmax) + w_q = (w / scale).round().clamp(-qmax, qmax) * scale + return w + (w_q - w).detach() +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + if max_tokens > 0 and tokens.numel() > max_tokens: + tokens = tokens[:max_tokens].contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +class CausalSLOT: + """Causal SLOT: optimize a delta vector using only already-scored context positions. + + Legal variant of SLOT — gradient only flows from positions that were scored + in PREVIOUS windows, not the current window's scored positions. + """ + def __init__(self, model_dim: int, device, lr: float = 0.005, steps: int = 8): + self.delta = torch.zeros(1, 1, model_dim, device=device, requires_grad=True) + self.optimizer = torch.optim.AdamW([self.delta], lr=lr) + self.steps = steps + + def reset(self): + """Reset delta to zero for each new batch.""" + self.delta.data.zero_() + self.optimizer = torch.optim.AdamW([self.delta], lr=self.optimizer.param_groups[0]['lr']) + + def optimize(self, model, input_ids: Tensor, context_mask: Tensor, target_ids: Tensor): + """Optimize delta using loss ONLY on context (already-scored) positions. 
+ + Args: + model: the frozen model + input_ids: [1, seq_len] input tokens + context_mask: [seq_len] bool mask, True for already-scored positions + target_ids: [1, seq_len] target tokens + + Returns: + optimized delta tensor + """ + if not context_mask.any(): + return self.delta.detach() + + for _ in range(self.steps): + self.optimizer.zero_grad() + + # Forward pass with delta injected at the last hidden state + with torch.no_grad(): + # Get hidden states from frozen model (up to final norm) + x = model.tok_emb(input_ids) + if hasattr(model, 'bigram_hash') and model.bigram_hash is not None: + x = x + model.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Run through blocks + skips = [] + for i in range(model.num_encoder_layers): + x = model.blocks[i](x, x0) + if model.num_decoder_layers > 0: + skips.append(x) + for i in range(model.num_decoder_layers): + if skips: + skip_idx = model.num_encoder_layers - 1 - i + if 0 <= skip_idx < len(skips): + x = x + model.skip_weights[min(i, len(model.skip_weights)-1)] * skips[skip_idx] + x = model.blocks[model.num_encoder_layers + i](x, x0) + + # Add delta (this is the only trainable part) + x_with_delta = x + self.delta + + # Final norm + logit projection (needs grad for delta) + x_normed = F.rms_norm(x_with_delta, (x_with_delta.size(-1),)) + + if model.tie_embeddings: + logits = F.linear(x_normed, model.tok_emb.weight) + else: + logits = model.lm_head(x_normed) + + if hasattr(model, 'logit_softcap') and model.logit_softcap > 0: + logits = model.logit_softcap * torch.tanh(logits / model.logit_softcap) + + logits = logits.float() + + # Loss ONLY on context positions (already-scored in previous windows) + context_logits = logits[0, context_mask] # [num_context, vocab] + context_targets = target_ids[0, context_mask] # [num_context] + + loss = F.cross_entropy(context_logits, context_targets) + loss.backward() + self.optimizer.step() + + return self.delta.detach() +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + # Initialize Causal SLOT if enabled + causal_slot = None + if args.causal_slot_enabled: + slot_dim = args.causal_slot_dim if args.causal_slot_dim > 0 else args.model_dim + causal_slot = CausalSLOT(slot_dim, device, lr=args.causal_slot_lr, steps=args.causal_slot_steps) + if causal_slot is not None: + # Causal SLOT path: process windows one at a time (delta per window) + t_slot_start = time.perf_counter() + # Freeze model weights for the duration + for p in base_model.parameters(): + p.requires_grad_(False) + for window_idx, ws in enumerate(my_windows): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + input_chunk = chunk[:-1].unsqueeze(0) # [1, wlen] + target_chunk = chunk[1:].unsqueeze(0) # [1, wlen] + # Pad to seq_len for uniform tensor shapes + if wlen < seq_len: + pad = torch.zeros(1, seq_len - wlen, dtype=torch.int64, device=device) + input_chunk = torch.cat([input_chunk, pad], dim=1) + target_chunk = torch.cat([target_chunk, pad], dim=1) + s = 0 if ws == 0 else max(wlen - stride, 0) + if causal_slot is not None and window_idx > 0 and s > 0: + # Context mask: positions 0..(s-1) are already-scored context + context_mask = torch.zeros(seq_len, dtype=torch.bool, device=device) + context_mask[:s] = True + # Optimize delta on context positions only + causal_slot.reset() + optimized_delta = causal_slot.optimize( + base_model, input_chunk, context_mask, target_chunk + ) + # Score the new stride positions with delta applied + with torch.no_grad(): + x = base_model.tok_emb(input_chunk) + if hasattr(base_model, 'bigram_hash') and base_model.bigram_hash is not None: + x = x + 
base_model.bigram_hash(input_chunk) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips_list: list[Tensor] = [] + for i in range(base_model.num_encoder_layers): + x = base_model.blocks[i](x, x0) + if base_model.num_decoder_layers > 0: + skips_list.append(x) + for i in range(base_model.num_decoder_layers): + if skips_list: + skip_idx = base_model.num_encoder_layers - 1 - i + if 0 <= skip_idx < len(skips_list): + x = x + base_model.skip_weights[min(i, len(base_model.skip_weights)-1)] * skips_list[skip_idx] + x = base_model.blocks[base_model.num_encoder_layers + i](x, x0) + x = x + optimized_delta + x_normed = base_model.final_norm(x) + if base_model.tie_embeddings: + logits_proj = F.linear(x_normed, base_model.tok_emb.weight) + else: + logits_proj = base_model.lm_head(x_normed) + logits_with_slot = base_model.logit_softcap * torch.tanh( + logits_proj / base_model.logit_softcap + ) + # Score only stride positions + score_logits = logits_with_slot[0, s:wlen].float() + score_targets = target_chunk[0, s:wlen] + else: + # Standard scoring without SLOT (first window or no context yet) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits_plain = base_model.forward_logits(input_chunk) + score_logits = logits_plain[0, s:wlen].float() + score_targets = target_chunk[0, s:wlen] + scored_nll = F.cross_entropy(score_logits, score_targets, reduction="none").to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = target_chunk[0, s:wlen] + prev = input_chunk[0, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (window_idx % 200 == 0 or window_idx == len(my_windows) - 1): + elapsed = time.perf_counter() - t_slot_start + print( + f" causal_slot window [{window_idx+1}/{len(my_windows)}] " + f"elapsed={elapsed:.1f}s" + ) + # Restore requires_grad + for p in base_model.parameters(): + p.requires_grad_(True) + else: + # Standard batch path (unchanged from original) + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, 
op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LoRAAdapter(nn.Module): + """Low-Rank Adaptation adapter for a linear layer.""" + def __init__(self, base_linear: nn.Linear, rank: int = 8): + super().__init__() + dim_in = base_linear.in_features + dim_out = base_linear.out_features + self.A = nn.Parameter(torch.empty(dim_in, rank)) + self.B = nn.Parameter(torch.zeros(rank, dim_out)) + nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) + self.base_linear = base_linear + self._hook_handle = None + + def _lora_hook(self, module: nn.Module, input: tuple[Tensor, ...], output: Tensor) -> Tensor: + """Forward hook that adds LoRA contribution to base linear output.""" + x = input[0] + lora_delta = (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + return output + lora_delta + + def register(self) -> None: + """Register the LoRA hook on the base linear layer.""" + if self._hook_handle is None: + self._hook_handle = self.base_linear.register_forward_hook(self._lora_hook) + + def remove(self) -> None: + """Remove the LoRA hook from the base linear layer.""" + if self._hook_handle is not None: + self._hook_handle.remove() + self._hook_handle = None + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + """Test-time training with LoRA adapters (only train adapter weights, not base model).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"lora_ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lora_rank={args.ttt_lora_rank} lora_lr={args.ttt_lora_lr} " + f"ttt_epochs={args.ttt_epochs}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Create LoRA adapters for attention projections and LM head + lora_adapters: list[LoRAAdapter] = [] + + # Add LoRA to attention blocks (c_q, c_v, proj) + if args.recurrence_repeats > 0: + # Recurrent model structure + all_blocks = list(base_model.head_blocks) + [base_model.shared_block] + list(base_model.tail_blocks) + else: + # Standard block structure + all_blocks = list(base_model.blocks) + + for block in all_blocks: + if hasattr(block, 'attn'): + attn = block.attn + # Add LoRA to Q, V, and output projection + for layer_name in ['c_q', 'c_v', 'proj']: + if hasattr(attn, layer_name): + base_linear = getattr(attn, layer_name) + adapter = LoRAAdapter(base_linear, 
rank=args.ttt_lora_rank).to(device) + lora_adapters.append(adapter) + + # Add LoRA to LM head (or tok_emb if tied) + if base_model.tie_embeddings: + # For tied embeddings, we need to patch tok_emb.weight usage in forward + # Create a pseudo-linear wrapper for the embedding weight + class EmbeddingAsLinear(nn.Module): + def __init__(self, weight: nn.Parameter): + super().__init__() + self.weight = weight + self.in_features = weight.size(1) + self.out_features = weight.size(0) + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight) + + emb_linear = EmbeddingAsLinear(base_model.tok_emb.weight).to(device) + lm_adapter = LoRAAdapter(emb_linear, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + elif base_model.lm_head is not None: + lm_adapter = LoRAAdapter(base_model.lm_head, rank=args.ttt_lora_rank).to(device) + lora_adapters.append(lm_adapter) + + # Collect all LoRA parameters + lora_params = [] + for adapter in lora_adapters: + lora_params.extend([adapter.A, adapter.B]) + + log0(f"lora_ttt_sliding:params lora={sum(p.numel() for p in lora_params)} " + f"base_frozen={sum(p.numel() for p in base_model.parameters())} " + f"adapters={len(lora_adapters)}") + + # Freeze all base model parameters + for p in base_model.parameters(): + p.requires_grad_(False) + + # Register all LoRA hooks + for adapter in lora_adapters: + adapter.register() + + # Use AdamW optimizer for LoRA params + optimizer = torch.optim.AdamW( + lora_params, + lr=args.ttt_lora_lr, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + + # Compile forward paths + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_lora = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_lora = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_lora = base_model.forward_logits + compiled_forward_lora = base_model.forward + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + # Score phase (inference only) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_lora(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Train phase (update LoRA params only, skip last chunk) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + + if chunk_seqs > 0: + # Cosine LR decay + cos_lr = args.ttt_lora_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + + if end_tok > val_tokens.numel(): + continue + + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_lora(x, y) + + loss.backward() + + # All-reduce LoRA gradients across GPUs + if world_size > 1: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Gradient clipping on LoRA params + torch.nn.utils.clip_grad_norm_(lora_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Clean up: remove all LoRA hooks and restore base model + for adapter in lora_adapters: + adapter.remove() + + # Restore requires_grad on base model + for p in base_model.parameters(): + p.requires_grad_(True) + + base_model.eval() + + log0(f"lora_ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + + return val_loss, val_bpb + +def eval_val_sliding_ttt( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + 
chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + # Compile forward paths for TTT (matches eval_val_sliding behavior) + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_logits_ttt = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_forward_ttt = torch.compile(base_model.forward, dynamic=False, fullgraph=True) + else: + compiled_logits_ttt = base_model.forward_logits + compiled_forward_ttt = base_model.forward + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits_ttt(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in 
range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = compiled_forward_ttt(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_INT6_MODE = bool(int(os.environ.get("INT6_QAT", "1"))) +_INT3_MODE = bool(int(os.environ.get("INT3_QUANT", "0"))) # Ternary quantization +_QUANT_OVERRIDE = int(os.environ.get("QUANT_BITS", "0")) # Override: 3,4,5,6,8 +if _QUANT_OVERRIDE > 0: + _QUANT_BITS = _QUANT_OVERRIDE + _QUANT_MAX_VAL = (1 << (_QUANT_BITS - 1)) - 1 # e.g. 
int5 -> 15, int4 -> 7 +elif _INT3_MODE: + _QUANT_MAX_VAL = 3 # Ternary: -3, -2, -1, 0, 1, 2, 3 + _QUANT_BITS = 3 +elif _INT6_MODE: + _QUANT_MAX_VAL = 31 + _QUANT_BITS = 6 +else: + _QUANT_MAX_VAL = 127 + _QUANT_BITS = 8 +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +_GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 1.0] +def quantize_float_tensor(t: Tensor, gpu_device: torch.device | None = None) -> tuple[Tensor, Tensor]: + # Run GPTQ-lite search on GPU when available (much faster for large 2D weight matrices) + dev = gpu_device if (gpu_device is not None and t.ndim == 2 and t.numel() > 0) else t.device + t32 = t.float().to(dev) + qmax = _QUANT_MAX_VAL + if t32.ndim == 2 and t32.numel() > 0: + abs_vals = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in _GPTQ_LITE_PERCENTILES: + if pct >= 1.0: + clip_abs = abs_vals.amax(dim=1) + else: + clip_abs = torch.quantile(abs_vals, pct, dim=1) + clip_abs = clip_abs.clamp_min(1e-8) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax) + recon = q * scale[:, None] + mse = (t32 - recon).square().mean(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = scale + else: + improved = mse < best_mse + if improved.any(): + best_mse = torch.where(improved, mse, best_mse) + best_q = torch.where(improved[:, None], q, best_q) + best_scale = torch.where(improved, scale, best_scale) + return best_q.cpu().to(torch.int8).contiguous(), best_scale.cpu().to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + if t32.ndim == 2: + return torch.empty_like(t32, dtype=torch.int8, device="cpu"), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE) + t32 = t32.cpu() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor], gpu_device: torch.device | None = None, + rank: int = 0, world_size: int = 1): + """Quantize state dict. 
When world_size > 1, distributes large tensor quantization across ranks.""" + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + # Separate tensors into passthrough vs needs-quantization + to_quantize: list[tuple[str, Tensor]] = [] + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + to_quantize.append((name, t)) + # Distribute quantization work across ranks (round-robin by tensor index) + for idx, (name, t) in enumerate(to_quantize): + if world_size > 1 and idx % world_size != rank: + continue # Another rank handles this tensor + q, s = quantize_float_tensor(t, gpu_device=gpu_device) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + # Gather results from all ranks + if world_size > 1 and dist.is_available() and dist.is_initialized(): + # Each rank broadcasts its quantized tensors to all others + for idx, (name, t) in enumerate(to_quantize): + owner = idx % world_size + if owner == rank: + # Broadcast shape info then tensor data + q_tensor = quantized[name] + s_tensor = scales[name] + # Send via broadcast + q_gpu = q_tensor.to(gpu_device) if gpu_device is not None else q_tensor.cuda() + s_gpu = s_tensor.to(gpu_device) if gpu_device is not None else s_tensor.cuda() + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + else: + # Allocate matching tensors and receive + q_shape = (t.shape[0],) + t.shape[1:] if t.ndim == 2 else t.shape + q_gpu = torch.empty(q_shape, dtype=torch.int8, device=gpu_device if gpu_device is not None else "cuda") + if t.ndim == 2 and t.numel() > 0: + s_gpu = torch.empty(t.shape[0], dtype=INT8_PER_ROW_SCALE_DTYPE, device=gpu_device if gpu_device is not None else "cuda") + else: + s_gpu = torch.empty((), dtype=torch.float32, device=gpu_device if gpu_device is not None else "cuda") + dist.broadcast(q_gpu, src=owner) + dist.broadcast(s_gpu, src=owner) + quantized[name] = q_gpu.cpu() + scales[name] = s_gpu.cpu() + dtypes[name] = str(t.dtype).removeprefix("torch.") + if s_gpu.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + stats["int8_payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + obj: dict[str, object] = { + "__quant_format__": f"int{_QUANT_BITS}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def 
dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _int6_qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._int6_qat and self.training: + w = int6_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, dim: int): + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, dim) + nn.init.zeros_(self.emb.weight) + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev * 31 + input_ids) % self.num_buckets + return self.emb(bigram_hash) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.ones(dim, 
dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + positions = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smooth = torch.cumsum(x, dim=1) / positions + return g * x + (1 - g) * smooth +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + rope_dim = rope_partial_dims if 0 < rope_partial_dims < self.head_dim else self.head_dim + self.rope_dim = rope_dim + self.rotary = Rotary(rope_dim, base=rope_base) + self.use_xsa = use_xsa + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dim < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] + k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] 
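+        # Notes on the attention path below: q_gain is a learned per-head
+        # scale applied after RMS-norm and RoPE, so it rescales each head's
+        # attention logits directly. SDPA runs with is_causal=True, and
+        # enable_gqa lets the fused kernel broadcast the grouped K/V heads
+        # across the query heads (e.g. 8 query heads with 4 KV heads means
+        # each KV head serves 2 query heads) without materializing repeats.
+        # When use_xsa is set, each token's own value is subtracted from the
+        # attention output; since the softmax weights sum to 1, this equals
+        # attending over value differences (v_j - v_i), so the block only
+        # adds cross-token information on top of the residual stream.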
+ y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(repeats, dim=1) + else: + v_expanded = v + y = y - v_expanded + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class MoEMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, num_experts: int): + super().__init__() + self.num_experts = num_experts + expert_hidden = (mlp_mult * dim) // num_experts + self.experts = nn.ModuleList([ + nn.ModuleDict({ + "fc": CastedLinear(dim, expert_hidden, bias=False), + "proj": CastedLinear(expert_hidden, dim, bias=False), + }) + for _ in range(num_experts) + ]) + for e in self.experts: + e["proj"]._zero_init = True + self.router = nn.Linear(dim, num_experts, bias=False) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + logits = self.router(x.detach()) + weights = torch.softmax(logits, dim=-1) + indices = weights.argmax(dim=-1) + out = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + mask = (indices == i) + if not mask.any(): + continue + tokens = x[mask] + h = F.leaky_relu(expert["fc"](tokens), negative_slope=0.5) + h = expert["proj"](h.square()) + out[mask] = h + return out +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_smeargate: bool = False, + rope_partial_dims: int = 0, + use_xsa: bool = False, + ln_scale_factor: float = 1.0, + moe_num_experts: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_partial_dims=rope_partial_dims, + use_xsa=use_xsa, + ) + self.mlp = MoEMLP(dim, mlp_mult, moe_num_experts) if moe_num_experts > 1 else MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smeargate else None + self.ln_scale_factor = ln_scale_factor + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.ln_scale_factor != 1.0: + normed = normed * self.ln_scale_factor + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_normed = self.mlp_norm(x) + if self.ln_scale_factor != 1.0: + mlp_normed = mlp_normed * self.ln_scale_factor + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_normed) + if self.smear is not None: + x = self.smear(x) + return x +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigramhash_buckets: int = 0, + use_smeargate: bool = False, + unet_skips: bool = True, + rope_partial_dims: int = 0, + ln_scale: bool = False, + xsa_layers: int = 0, + recurrence_repeats: int = 0, + recurrence_unique_head: int = 2, + recurrence_unique_tail: int = 2, + moe_num_experts: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigramhash_buckets, model_dim) if bigramhash_buckets > 0 else None + self.recurrence_repeats = recurrence_repeats + if recurrence_repeats > 0: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.zeros(0, dtype=torch.float32)) + n_head = recurrence_unique_head + n_tail = recurrence_unique_tail + effective_layers = n_head + recurrence_repeats + n_tail + def make_block(layer_idx: int, eff_total: int) -> Block: + xsa_start_idx = eff_total - xsa_layers if xsa_layers > 0 else eff_total + return Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(layer_idx >= xsa_start_idx), + ln_scale_factor=(1.0 / math.sqrt(layer_idx + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + self.head_blocks = nn.ModuleList([make_block(i, effective_layers) for i in range(n_head)]) + self.shared_block = make_block(n_head, effective_layers) + self.recurrence_gate = nn.Parameter(torch.full((recurrence_repeats,), 0.5, dtype=torch.float32)) + self.tail_blocks = nn.ModuleList([ + make_block(n_head + recurrence_repeats + i, effective_layers) for i in range(n_tail) + ]) + self.blocks = nn.ModuleList() + else: + if unet_skips: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + else: + self.num_encoder_layers = num_layers + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + xsa_start = num_layers - xsa_layers if xsa_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + use_smeargate=use_smeargate, + rope_partial_dims=rope_partial_dims, + use_xsa=(i >= xsa_start), + ln_scale_factor=(1.0 / math.sqrt(i + 1)) if ln_scale else 1.0, + moe_num_experts=moe_num_experts, + ) + for i in range(num_layers) + ] + ) + self.head_blocks = nn.ModuleList() + self.shared_block = None + self.recurrence_gate = None + self.tail_blocks = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> 
Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if getattr(self, '_return_logits', False): + return logits.float().reshape(input_ids.shape[0], input_ids.shape[1], -1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrence_repeats > 0: + for block in self.head_blocks: + x = block(x, x0) + for r in range(self.recurrence_repeats): + gate = torch.sigmoid(self.recurrence_gate[r]).to(dtype=x.dtype) + residual = x + x = self.shared_block(x, x0) + x = gate * x + (1.0 - gate) * residual + for block in self.tail_blocks: + x = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + _use_flash = sys.platform != "win32" + enable_cudnn_sdp(False) + enable_flash_sdp(_use_flash) + enable_mem_efficient_sdp(False) + enable_math_sdp(not _use_flash) + # Create per-run output directory so artifacts/logs never collide + run_dir = Path(os.environ.get("RUN_OUTPUT_DIR", f"runs/{args.run_id}")) + if master_process: + run_dir.mkdir(parents=True, exist_ok=True) + # Backward-compat: also keep logs/ symlink/copy + os.makedirs("logs", exist_ok=True) + logfile = None + if master_process: + logfile = str(run_dir / "log.txt") + # Also write to legacy logs/ path for compatibility + legacy_logfile = f"logs/{args.run_id}.txt" + if not Path(legacy_logfile).exists(): + try: + os.symlink(os.path.abspath(logfile), legacy_logfile) + except OSError: + pass # Windows or cross-device; will write to both + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._int6_qat = args.int6_qat + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigramhash_buckets=args.bigramhash_buckets, + use_smeargate=args.smeargate, + unet_skips=args.unet_skips, + rope_partial_dims=args.rope_partial_dims, + ln_scale=args.ln_scale, + xsa_layers=args.xsa_layers, + recurrence_repeats=args.recurrence_repeats, + recurrence_unique_head=args.recurrence_unique_head, + recurrence_unique_tail=args.recurrence_unique_tail, + moe_num_experts=args.moe_num_experts, + ).to(device).bfloat16() + for module 
in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if int(os.environ.get("TORCHDYNAMO_DISABLE", "0")): + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + if args.recurrence_repeats > 0: + block_named_params = ( + list(base_model.head_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.tail_blocks.named_parameters()) + ) + if base_model.recurrence_gate is not None: + block_named_params.append(("recurrence_gate", base_model.recurrence_gate)) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params: list[nn.Parameter] = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.emb.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash={_use_flash} mem_efficient=False math={not _use_flash}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0( + f"meta_baseline: int6_qat:{args.int6_qat} bigramhash:{args.bigramhash_buckets} " + f"smeargate:{args.smeargate} ema_decay:{args.ema_decay} unet_skips:{args.unet_skips} " + f"warmdown_iters:{args.warmdown_iters} compression:{'zstd-22' if _HAS_ZSTD else 'zlib-9'}" + ) + log0( + f"strong_gains: 
leaky_relu_sq:True rope_partial_dims:{args.rope_partial_dims} " + f"ln_scale:{args.ln_scale} xsa_layers:{args.xsa_layers}" + ) + if args.recurrence_repeats > 0: + eff = args.recurrence_unique_head + args.recurrence_repeats + args.recurrence_unique_tail + log0( + f"depth_recurrence: head={args.recurrence_unique_head} repeats={args.recurrence_repeats} " + f"tail={args.recurrence_unique_tail} effective_layers={eff}" + ) + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = {name: param.data.clone() for name, param in base_model.named_parameters()} + log0(f"ema:enabled decay={args.ema_decay}") + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + ema_backup: dict[str, Tensor] | None = None + if ema_state is not None: + ema_backup = {} + with torch.no_grad(): + for name, p in base_model.named_parameters(): + ema_backup[name] = p.data.clone() + p.data.copy_(ema_state[name]) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + if 
+            if ema_backup is not None:
+                with torch.no_grad():
+                    for name, p in base_model.named_parameters():
+                        p.data.copy_(ema_backup[name])
+                del ema_backup
+            log0(
+                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
+            )
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(
+                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
+                    f"step:{step}/{args.iterations}"
+                )
+            break
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        for micro_step in range(grad_accum_steps):
+            if distributed:
+                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = model(x, y)
+            train_loss += loss.detach()
+            (loss * grad_scale).backward()
+        train_loss /= grad_accum_steps
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+        if ema_state is not None:
+            with torch.no_grad():
+                decay = args.ema_decay
+                for name, p in base_model.named_parameters():
+                    ema_state[name].mul_(decay).add_(p.data, alpha=1.0 - decay)
+        step += 1
+        if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and master_process:
+            ckpt_dir = run_dir / "checkpoints"
+            ckpt_dir.mkdir(exist_ok=True)
+            ckpt_path = ckpt_dir / f"ckpt_step{step}.pt"
+            ckpt_data = {"step": step, "model": base_model.state_dict()}
+            if ema_state is not None:
+                ckpt_data["ema"] = {k: v.cpu() for k, v in ema_state.items()}
+            torch.save(ckpt_data, ckpt_path)
+            ckpts = sorted(ckpt_dir.glob("ckpt_step*.pt"), key=lambda p: p.stat().st_mtime)
+            for old in ckpts[:-2]:
+                old.unlink(missing_ok=True)
+            log0(f"checkpoint:saved {ckpt_path} (kept {min(len(ckpts), 2)} latest)")
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
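+    # Training is done: export swaps in the EMA weights (when enabled), serializes the fp32
+    # state dict as a size reference, then quantizes and compresses the submission artifact.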
+    if ema_state is not None:
+        with torch.no_grad():
+            for name, p in base_model.named_parameters():
+                p.data.copy_(ema_state[name])
+        log0("ema:applied EMA weights for export")
+    fp32_path = str(run_dir / "final_model.pt")
+    if master_process:
+        torch.save(base_model.state_dict(), fp32_path)
+        model_bytes = os.path.getsize(fp32_path)
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes → {fp32_path}")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+    quant_label = f"int{_QUANT_BITS}"
+    compress_label = "zstd" if _HAS_ZSTD else "zlib"
+    log0(f"quantization:start {quant_label} distributing across {world_size} GPUs")
+    torch.cuda.synchronize()
+    t_quant = time.perf_counter()
+    quant_obj, quant_stats = quantize_state_dict_int8(
+        base_model.state_dict(), gpu_device=device, rank=rank, world_size=world_size,
+    )
+    torch.cuda.synchronize()
+    log0(f"quantization:done {quant_label} time={1000.0 * (time.perf_counter() - t_quant):.0f}ms")
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    if _HAS_ZSTD:
+        compressor = zstd.ZstdCompressor(level=22, threads=-1)  # -1 = use all CPU cores
+        quant_blob = compressor.compress(quant_raw)
+    else:
+        quant_blob = zlib.compress(quant_raw, level=9)
+    quant_raw_bytes = len(quant_raw)
+    artifact_name = str(run_dir / f"final_model.{quant_label}.ptz")
+    if master_process:
+        with open(artifact_name, "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize(artifact_name)
+        code_bytes = len(code.encode("utf-8"))
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+        log0(
+            f"Serialized model {quant_label}+{compress_label}: {quant_file_bytes} bytes "
+            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
+        )
+        log0(f"Total submission size {quant_label}+{compress_label}: {quant_file_bytes + code_bytes} bytes")
+        log0(f"artifact_path:{artifact_name}")
+    if distributed:
+        dist.barrier()
+    with open(artifact_name, "rb") as f:
+        quant_blob_disk = f.read()
+    if _HAS_ZSTD:
+        decompressor = zstd.ZstdDecompressor()
+        quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu")
+    else:
+        quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
+    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args,
+        model,
+        rank,
+        world_size,
+        device,
+        grad_accum_steps,
+        val_tokens,
+        base_bytes_lut,
+        has_leading_space_lut,
+        is_boundary_token_lut,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"final_{quant_label}_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_{quant_label}_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
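+    # Optional sliding-window evaluation (stride < train_seq_len): overlapping windows give
+    # each scored token more left context than disjoint block-wise evaluation.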
+    if args.eval_stride > 0 and args.eval_stride < args.train_seq_len:
+        torch.cuda.synchronize()
+        t_slide = time.perf_counter()
+        sw_val_loss, sw_val_bpb = eval_val_sliding(
+            args, base_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride,
+        )
+        torch.cuda.synchronize()
+        log0(
+            f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} "
+            f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms"
+        )
+        log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}")
+    if args.causal_slot_enabled:
+        torch.cuda.synchronize()
+        t_cslot = time.perf_counter()
+        slot_stride = args.eval_stride if args.eval_stride > 0 else 64
+        log0(
+            f"causal_slot:start stride={slot_stride} steps={args.causal_slot_steps} "
+            f"lr={args.causal_slot_lr} dim={args.causal_slot_dim if args.causal_slot_dim > 0 else args.model_dim}"
+        )
+        cslot_val_loss, cslot_val_bpb = eval_val_sliding(
+            args, base_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=slot_stride,
+        )
+        torch.cuda.synchronize()
+        log0(
+            f"causal_slot val_loss:{cslot_val_loss:.4f} val_bpb:{cslot_val_bpb:.4f} "
+            f"stride:{slot_stride} eval_time:{1000.0 * (time.perf_counter() - t_cslot):.0f}ms"
+        )
+        log0(f"causal_slot_exact val_loss:{cslot_val_loss:.8f} val_bpb:{cslot_val_bpb:.8f}")
+    if args.ttt_enabled:
+        torch.cuda.synchronize()
+        t_ttt = time.perf_counter()
+
+        # Choose LoRA TTT or full TTT based on TTT_LORA_ENABLED flag
+        if args.ttt_lora_enabled:
+            ttt_loss, ttt_bpb = eval_val_sliding_lora_ttt(
+                args, base_model, rank, world_size, device,
+                val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+                stride=args.eval_stride if args.eval_stride > 0 else 64,
+                log0=log0,
+            )
+            ttt_label = "legal_lora_ttt"
+        else:
+            ttt_loss, ttt_bpb = eval_val_sliding_ttt(
+                args, base_model, rank, world_size, device,
+                val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+                stride=args.eval_stride if args.eval_stride > 0 else 64,
+                log0=log0,
+            )
+            ttt_label = "legal_ttt"
+
+        torch.cuda.synchronize()
+        log0(
+            f"{ttt_label} val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} "
+            f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms"
+        )
+        log0(f"{ttt_label}_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}")
+    if int(os.environ.get("NGRAM_EVAL", "0")):
+        ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", "9"))
+        log0(f"ngram_eval:starting max_order={ngram_max_order}")
+        torch.cuda.synchronize()
+        t_ngram = time.perf_counter()
+        try:
+            from ngram_eval import eval_val_ngram
+            ng_val_loss, ng_val_bpb = eval_val_ngram(
+                args, base_model, model, rank, world_size, device, grad_accum_steps,
+                val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+                max_order=ngram_max_order, log_fn=log0,
+            )
+            torch.cuda.synchronize()
+            log0(
+                f"ngram_eval_result val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} "
+                f"eval_time:{1000.0 * (time.perf_counter() - t_ngram):.0f}ms"
+            )
+            log0(f"ngram_eval_result_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}")
+        except Exception as e:
+            log0(f"ngram_eval:FAILED {e}")
+    if distributed:
+        dist.destroy_process_group()
+if __name__ == "__main__":
+    main()
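+# Offline artifact check — a minimal sketch, not part of the training flow. The .ptz file is a
+# zstd- or zlib-compressed torch.save payload, so it can be inspected on CPU roughly as below,
+# assuming this module's dequantize_state_dict_int8 helper and the model class are importable
+# (the artifact filename follows the f"final_model.{quant_label}.ptz" pattern above):
+#
+#     import io, zlib
+#     import torch
+#     with open("final_model.int6.ptz", "rb") as f:
+#         blob = f.read()
+#     try:
+#         import zstandard as zstd
+#         raw = zstd.ZstdDecompressor().decompress(blob)
+#     except ImportError:
+#         raw = zlib.decompress(blob)
+#     state = torch.load(io.BytesIO(raw), map_location="cpu")
+#     model.load_state_dict(dequantize_state_dict_int8(state), strict=True)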