From 5d3c23069e8ee57a325eb53af0cbf16a931fcd86 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sat, 21 Mar 2026 02:34:49 +0700 Subject: [PATCH 01/32] Add Mixture of Softmax (MoS) with low-rank option for softmax bottleneck Fix critical bugs: MoS params now included in optimizer groups, use NLL loss (not cross_entropy) since MoS returns log-probs, skip logit softcap for MoS path, re-normalize after LoRA correction. Low-rank factorization (MOS_RANK=64) keeps artifact under 16MB budget. Enable via: USE_MOS=1 MOS_K=2 MOS_RANK=64 Co-Authored-By: Claude Opus 4.6 --- run_pilot.sh | 42 +++++++++++++++++++ test_mos.py | 98 +++++++++++++++++++++++++++++++++++++++++++++ train_gpt.py | 111 +++++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 248 insertions(+), 3 deletions(-) create mode 100755 run_pilot.sh create mode 100644 test_mos.py diff --git a/run_pilot.sh b/run_pilot.sh new file mode 100755 index 0000000000..cb12b317b3 --- /dev/null +++ b/run_pilot.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Quick start script for 1x H100 MoS pilot +# Run from the parameter-golf repo root directory + +set -e + +echo "=== Parameter Golf MoS Pilot ===" +echo "Date: $(date)" +echo "GPU: 1x H100 SXM" +echo "" + +# Configuration (all via env vars — train_gpt.py has no argparse) +ITERATIONS=2000 +SEED=42 +MOS_K=2 +MOS_RANK=64 # Low-rank to fit in 16MB budget (~100KB vs ~500KB full-rank) + +echo "Configuration:" +echo " Iterations: $ITERATIONS" +echo " Seed: $SEED" +echo " MoS K: $MOS_K" +echo " MoS Rank: $MOS_RANK (0=full-rank)" +echo "" + +# Baseline run +echo "=== Running Baseline ===" +ITERATIONS=$ITERATIONS SEED=$SEED MAX_WALLCLOCK_SECONDS=99999 \ + python3 train_gpt.py 2>&1 | tee baseline_log.txt + +echo "" +echo "=== Running MoS K=$MOS_K rank=$MOS_RANK ===" +ITERATIONS=$ITERATIONS SEED=$SEED MAX_WALLCLOCK_SECONDS=99999 \ + USE_MOS=1 MOS_K=$MOS_K MOS_RANK=$MOS_RANK \ + python3 train_gpt.py 2>&1 | tee mos_k${MOS_K}_r${MOS_RANK}_log.txt + +echo "" +echo "=== Done ===" +echo 
"Compare results:" +echo " grep 'val_bpb' baseline_log.txt" +echo " grep 'val_bpb' mos_k${MOS_K}_r${MOS_RANK}_log.txt" +echo " grep 'bytes' baseline_log.txt" +echo " grep 'bytes' mos_k${MOS_K}_r${MOS_RANK}_log.txt" \ No newline at end of file diff --git a/test_mos.py b/test_mos.py new file mode 100644 index 0000000000..6732b343b0 --- /dev/null +++ b/test_mos.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""Quick test to verify MoS implementation works correctly.""" + +import torch +import torch.nn.functional as F + +# Mock CastedLinear for testing +class CastedLinear(torch.nn.Linear): + def forward(self, x): + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class MixtureOfSoftmax(torch.nn.Module): + """Mixture of Softmax output layer for breaking the softmax bottleneck.""" + + def __init__(self, model_dim: int, vocab_size: int, n_mixtures: int = 2, rank: int = 0): + super().__init__() + self.n_mixtures = n_mixtures + self.model_dim = model_dim + self.vocab_size = vocab_size + self.rank = rank + + if rank > 0: + self.proj_down = CastedLinear(model_dim, rank, bias=False) + self.proj_up = CastedLinear(rank, n_mixtures * model_dim, bias=False) + torch.nn.init.normal_(self.proj_down.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(self.proj_up.weight, mean=0.0, std=0.02) + else: + self.projections = CastedLinear(model_dim, n_mixtures * model_dim, bias=False) + torch.nn.init.normal_(self.projections.weight, mean=0.0, std=0.02) + + self.gate = CastedLinear(model_dim, n_mixtures, bias=False) + torch.nn.init.normal_(self.gate.weight, mean=0.0, std=0.02) + + def forward(self, hidden: torch.Tensor, weight_matrix: torch.Tensor) -> torch.Tensor: + bsz, seq_len, dim = hidden.shape + K = self.n_mixtures + + pi = F.softmax(self.gate(hidden), dim=-1) + + if self.rank > 0: + projected = self.proj_up(self.proj_down(hidden)).view(bsz, seq_len, K, dim) + else: + projected = self.projections(hidden).view(bsz, 
seq_len, K, dim) + + logits = torch.einsum('bskd,vd->bskv', projected, weight_matrix) + + log_probs = F.log_softmax(logits, dim=-1) + log_pi = torch.log(pi.unsqueeze(-1) + 1e-10) + mixed_log_probs = torch.logsumexp(log_probs + log_pi, dim=2) + + return mixed_log_probs + + +def test_mos(): + """Test MoS forward pass.""" + print("Testing MoS implementation...") + + vocab_size = 1024 + model_dim = 512 + batch_size = 2 + seq_len = 16 + + hidden = torch.randn(batch_size, seq_len, model_dim) + weight_matrix = torch.randn(vocab_size, model_dim) + + for K in [1, 2, 3]: + for rank in [0, 32, 64]: + mos = MixtureOfSoftmax(model_dim, vocab_size, n_mixtures=K, rank=rank) + output = mos(hidden, weight_matrix) + + assert output.shape == (batch_size, seq_len, vocab_size), f"Wrong shape: {output.shape}" + + # Verify output is valid log probabilities + probs = torch.exp(output) + prob_sum = probs.sum(dim=-1) + assert torch.allclose(prob_sum, torch.ones_like(prob_sum), atol=1e-4), \ + f"K={K} rank={rank}: probs don't sum to 1: {prob_sum.mean():.6f}" + + # Count parameters + params = sum(p.numel() for p in mos.parameters()) + size_kb = params / 1024 + print(f" K={K} rank={rank:>3d}: {params:>10,} params ({size_kb:>7.1f} KB at int8)") + + # Verify NLL loss works correctly with MoS output + mos = MixtureOfSoftmax(model_dim, vocab_size, n_mixtures=2, rank=64) + output = mos(hidden, weight_matrix) + targets = torch.randint(0, vocab_size, (batch_size, seq_len)) + loss = F.nll_loss(output.reshape(-1, vocab_size), targets.reshape(-1)) + assert loss.isfinite(), f"NLL loss is not finite: {loss}" + print(f"\n NLL loss test: {loss.item():.4f} (should be ~6.93 for random)") + + print("\nAll tests passed!") + + +if __name__ == "__main__": + test_mos() diff --git a/train_gpt.py b/train_gpt.py index 85e2cc463a..f74f343143 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -93,10 +93,17 @@ class Hyperparameters: ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) ttt_batch_size = 
int(os.environ.get("TTT_BATCH_SIZE", 64)) + # Mixture of Softmax (MoS) output layer - breaks softmax bottleneck. + # At vocab=1024, dim=512, standard softmax has rank ≤ 513 (binding constraint). + # MoS with K=2 lifts this to rank ≤ 1026, enabling richer output distributions. + use_mos = bool(int(os.environ.get("USE_MOS", "0"))) + mos_k = int(os.environ.get("MOS_K", 2)) + mos_rank = int(os.environ.get("MOS_RANK", 64)) # 0 = full-rank, >0 = low-rank factorization + # ----------------------------- -# MUON OPTIMIZER +# MUON OPTIMIZER # ----------------------------- -# +# # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ @@ -658,6 +665,73 @@ def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Te return x +class MixtureOfSoftmax(nn.Module): + """Mixture of Softmax output layer for breaking the softmax bottleneck. + + At vocab=1024, dim=512, the standard softmax has rank ≤ 513. + MoS with K=2 lifts this to rank ≤ 1026, enabling richer output distributions. + + When mos_rank > 0, uses low-rank factorization to save parameters: + instead of dim -> K*dim projection, uses dim -> rank -> K*dim. + + Paper: Yang et al. (2018), "Breaking the Softmax Bottleneck", ICLR 2018. 
+ """ + + def __init__(self, model_dim: int, vocab_size: int, n_mixtures: int = 2, rank: int = 0): + super().__init__() + self.n_mixtures = n_mixtures + self.model_dim = model_dim + self.vocab_size = vocab_size + self.rank = rank + + if rank > 0: + # Low-rank factorization: dim -> rank -> K*dim + self.proj_down = CastedLinear(model_dim, rank, bias=False) + self.proj_up = CastedLinear(rank, n_mixtures * model_dim, bias=False) + nn.init.normal_(self.proj_down.weight, mean=0.0, std=0.02) + nn.init.normal_(self.proj_up.weight, mean=0.0, std=0.02) + else: + # Full-rank: dim -> K*dim + self.projections = CastedLinear(model_dim, n_mixtures * model_dim, bias=False) + nn.init.normal_(self.projections.weight, mean=0.0, std=0.02) + + # Mixing weight predictor + self.gate = CastedLinear(model_dim, n_mixtures, bias=False) + nn.init.normal_(self.gate.weight, mean=0.0, std=0.02) + + def forward(self, hidden: Tensor, weight_matrix: Tensor) -> Tensor: + """Compute mixed softmax distribution. + + Args: + hidden: (bsz, seq_len, dim) - final hidden states + weight_matrix: (vocab_size, dim) - tied embedding weights + + Returns: + log_probs: (bsz, seq_len, vocab_size) - mixed log probabilities + """ + bsz, seq_len, dim = hidden.shape + K = self.n_mixtures + + # Compute mixing weights: (bsz, seq, K) + pi = F.softmax(self.gate(hidden), dim=-1) + + # Project to K different spaces: (bsz, seq, K * dim) -> (bsz, seq, K, dim) + if self.rank > 0: + projected = self.proj_up(self.proj_down(hidden)).view(bsz, seq_len, K, dim) + else: + projected = self.projections(hidden).view(bsz, seq_len, K, dim) + + # Compute K different logit vectors: (bsz, seq, K, vocab) + logits = torch.einsum('bskd,vd->bskv', projected, weight_matrix) + + # Mix softmax distributions using log-space for numerical stability + log_probs = F.log_softmax(logits, dim=-1) # (bsz, seq, K, vocab) + log_pi = torch.log(pi.unsqueeze(-1) + 1e-10) # (bsz, seq, K, 1) + mixed_log_probs = torch.logsumexp(log_probs + log_pi, dim=2) # (bsz, 
seq, vocab) + + return mixed_log_probs + + class GPT(nn.Module): def __init__( self, @@ -672,6 +746,9 @@ def __init__( logit_softcap: float, rope_base: float, qk_gain_init: float, + use_mos: bool = False, + mos_k: int = 2, + mos_rank: int = 0, ): super().__init__() if logit_softcap <= 0.0: @@ -679,6 +756,7 @@ def __init__( self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap + self.use_mos = use_mos self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = num_layers - self.num_encoder_layers @@ -698,6 +776,11 @@ def __init__( ] ) self.final_norm = RMSNorm() + # MoS output layer (optional) - breaks softmax bottleneck + if use_mos: + self.mos = MixtureOfSoftmax(model_dim, vocab_size, n_mixtures=mos_k, rank=mos_rank) + else: + self.mos = None self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True @@ -730,7 +813,18 @@ def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: vd = lora.v_loras[bi] if lora else None x = self.blocks[bi](x, x0, qd, vd) x = self.final_norm(x) - if self.tie_embeddings: + # Output layer + if self.mos is not None and self.tie_embeddings: + # MoS: returns log-probs (already log-softmaxed), use NLL loss directly + log_probs = self.mos(x, self.tok_emb.weight) + if lora: + # LoRA correction breaks normalization; re-normalize via log_softmax + log_probs = F.log_softmax(log_probs + lora.lm_head_lora(x), dim=-1) + bsz, sl, V = log_probs.shape + return F.nll_loss( + log_probs.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + return F.nll_loss(log_probs.float().reshape(-1, log_probs.size(-1)), target_ids.reshape(-1), reduction="mean") + elif self.tie_embeddings: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) @@ -1065,6 +1159,9 @@ def log0(msg: str, 
console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + use_mos=args.use_mos, + mos_k=args.mos_k, + mos_rank=args.mos_rank, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): @@ -1093,6 +1190,14 @@ def log0(msg: str, console: bool = True) -> None: ] if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + # MoS parameters: 2D projection weights go to Muon, gate goes to scalar optimizer + if base_model.mos is not None: + if base_model.mos.rank > 0: + matrix_params.append(base_model.mos.proj_down.weight) + matrix_params.append(base_model.mos.proj_up.weight) + else: + matrix_params.append(base_model.mos.projections.weight) + scalar_params.append(base_model.mos.gate.weight) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr optimizer_tok = torch.optim.Adam( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], From 9d61c1736b94c07b184ab4539ce6fdaf6a1bda8d Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sat, 21 Mar 2026 02:36:02 +0700 Subject: [PATCH 02/32] Add one-shot RunPod setup and pilot run script Clones fork, downloads dataset, runs baseline vs MoS K=2 rank=64 A/B comparison (10 min each on 1x H100). Co-Authored-By: Claude Opus 4.6 --- setup_and_run.sh | 85 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100755 setup_and_run.sh diff --git a/setup_and_run.sh b/setup_and_run.sh new file mode 100755 index 0000000000..ab7088568b --- /dev/null +++ b/setup_and_run.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# === Parameter Golf: MoS Pilot on 1x H100 === +# Paste this entire script into your RunPod terminal. +# Uses the pre-built runpod/parameter-golf:latest template. 
+# Total time: ~18 min download + 10 min baseline + 10 min MoS = ~40 min + +set -e + +echo "=== Step 1: Clone fork and download dataset ===" +cd /workspace +git clone https://github.com/User123331/parameter-golf.git +cd parameter-golf + +# Download full dataset (takes ~18 min, needs all 80 shards for proper training) +python3 data/cached_challenge_fineweb.py --variant sp1024 & +DOWNLOAD_PID=$! + +echo "Dataset downloading in background (PID: $DOWNLOAD_PID)..." +echo "Waiting for download to complete..." +wait $DOWNLOAD_PID +echo "Dataset download complete!" + +# Verify dataset +ls -la data/datasets/fineweb10B_sp1024/ | head -5 +echo "Train shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l)" +echo "Val shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_val_*.bin 2>/dev/null | wc -l)" + +echo "" +echo "=== Step 2: Run Baseline (10 min, 1x H100) ===" +echo "Start time: $(date)" + +RUN_ID=baseline_pilot \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | tee /workspace/baseline_log.txt + +echo "" +echo "Baseline done at: $(date)" +echo "" + +# Save baseline artifact +cp final_model.int8.ptz /workspace/baseline_model.int8.ptz 2>/dev/null || true + +echo "=== Step 3: Run MoS K=2 rank=64 (10 min, 1x H100) ===" +echo "Start time: $(date)" + +RUN_ID=mos_k2_r64_pilot \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +USE_MOS=1 \ +MOS_K=2 \ +MOS_RANK=64 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | tee /workspace/mos_log.txt + +echo "" +echo "MoS done at: $(date)" + +# Save MoS artifact +cp final_model.int8.ptz /workspace/mos_model.int8.ptz 
2>/dev/null || true + +echo "" +echo "============================================" +echo "=== RESULTS COMPARISON ===" +echo "============================================" +echo "" +echo "--- Baseline ---" +grep -E 'val_bpb|val_loss|bytes|param' /workspace/baseline_log.txt | tail -10 +echo "" +echo "--- MoS K=2 rank=64 ---" +grep -E 'val_bpb|val_loss|bytes|param' /workspace/mos_log.txt | tail -10 +echo "" +echo "--- Artifact Sizes ---" +ls -la /workspace/baseline_model.int8.ptz /workspace/mos_model.int8.ptz 2>/dev/null +echo "" +echo "Done! Copy these results back." From c8de91e169327fb0345b55e17e2e2231dde10389 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sat, 21 Mar 2026 02:37:12 +0700 Subject: [PATCH 03/32] Simplify pilot script: MoS only, skip redundant baseline Baseline bpb already known from prior runs (~1.2244). Saves 10 min of GPU time. Co-Authored-By: Claude Opus 4.6 --- setup_and_run.sh | 61 ++++++++---------------------------------------- 1 file changed, 10 insertions(+), 51 deletions(-) diff --git a/setup_and_run.sh b/setup_and_run.sh index ab7088568b..c25241d369 100755 --- a/setup_and_run.sh +++ b/setup_and_run.sh @@ -1,8 +1,8 @@ #!/bin/bash # === Parameter Golf: MoS Pilot on 1x H100 === -# Paste this entire script into your RunPod terminal. -# Uses the pre-built runpod/parameter-golf:latest template. -# Total time: ~18 min download + 10 min baseline + 10 min MoS = ~40 min +# Paste this into your RunPod terminal. +# Total time: ~18 min download + 10 min MoS run = ~28 min +# Baseline already known: ~1.2244 bpb (10min/8xH100) or ~1.2074 (4hr/8xH100) set -e @@ -11,42 +11,15 @@ cd /workspace git clone https://github.com/User123331/parameter-golf.git cd parameter-golf -# Download full dataset (takes ~18 min, needs all 80 shards for proper training) -python3 data/cached_challenge_fineweb.py --variant sp1024 & -DOWNLOAD_PID=$! - -echo "Dataset downloading in background (PID: $DOWNLOAD_PID)..." -echo "Waiting for download to complete..." 
-wait $DOWNLOAD_PID -echo "Dataset download complete!" +# Download full dataset (~18 min) +python3 data/cached_challenge_fineweb.py --variant sp1024 # Verify dataset -ls -la data/datasets/fineweb10B_sp1024/ | head -5 echo "Train shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l)" echo "Val shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_val_*.bin 2>/dev/null | wc -l)" echo "" -echo "=== Step 2: Run Baseline (10 min, 1x H100) ===" -echo "Start time: $(date)" - -RUN_ID=baseline_pilot \ -DATA_PATH=./data/datasets/fineweb10B_sp1024 \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -SEED=42 \ -MAX_WALLCLOCK_SECONDS=600 \ -VAL_LOSS_EVERY=500 \ -TRAIN_LOG_EVERY=100 \ -torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | tee /workspace/baseline_log.txt - -echo "" -echo "Baseline done at: $(date)" -echo "" - -# Save baseline artifact -cp final_model.int8.ptz /workspace/baseline_model.int8.ptz 2>/dev/null || true - -echo "=== Step 3: Run MoS K=2 rank=64 (10 min, 1x H100) ===" +echo "=== Step 2: Run MoS K=2 rank=64 (10 min, 1x H100) ===" echo "Start time: $(date)" RUN_ID=mos_k2_r64_pilot \ @@ -63,23 +36,9 @@ TRAIN_LOG_EVERY=100 \ torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | tee /workspace/mos_log.txt echo "" -echo "MoS done at: $(date)" - -# Save MoS artifact -cp final_model.int8.ptz /workspace/mos_model.int8.ptz 2>/dev/null || true - -echo "" -echo "============================================" -echo "=== RESULTS COMPARISON ===" -echo "============================================" -echo "" -echo "--- Baseline ---" -grep -E 'val_bpb|val_loss|bytes|param' /workspace/baseline_log.txt | tail -10 -echo "" -echo "--- MoS K=2 rank=64 ---" -grep -E 'val_bpb|val_loss|bytes|param' /workspace/mos_log.txt | tail -10 +echo "=== RESULTS ===" echo "" -echo "--- Artifact Sizes ---" -ls -la /workspace/baseline_model.int8.ptz /workspace/mos_model.int8.ptz 2>/dev/null +grep -E 
'val_bpb|val_loss|bytes|param|model_params' /workspace/mos_log.txt | tail -15
 echo ""
-echo "Done! Copy these results back."
+echo "Known baseline: val_bpb ~1.2244 (10min/8xH100)"
+echo "Done at: $(date)"

From 68520fe859e80bfd8fe3418e84b08925db6c15da Mon Sep 17 00:00:00 2001
From: Billy Endson
Date: Sat, 21 Mar 2026 02:43:54 +0700
Subject: [PATCH 04/32] Fix setup script: remove clone step, assume already in repo

Co-Authored-By: Claude Opus 4.6

---
 setup_and_run.sh | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/setup_and_run.sh b/setup_and_run.sh
index c25241d369..e4b4e824a5 100755
--- a/setup_and_run.sh
+++ b/setup_and_run.sh
@@ -6,10 +6,9 @@
 
 set -e
 
-echo "=== Step 1: Clone fork and download dataset ==="
-cd /workspace
-git clone https://github.com/User123331/parameter-golf.git
-cd parameter-golf
+echo "=== Step 1: Download dataset ==="
+# Run from the repo root (already cloned)
+cd /workspace/parameter-golf
 
 # Download full dataset (~18 min)
 python3 data/cached_challenge_fineweb.py --variant sp1024

From 01f071aca5ab46afbc47a7b9a8b22ac4ec9259e4 Mon Sep 17 00:00:00 2001
From: Billy Endson
Date: Sat, 21 Mar 2026 02:47:02 +0700
Subject: [PATCH 05/32] Add HF_TOKEN to setup script for faster dataset downloads

Co-Authored-By: Claude Opus 4.6

---
 setup_and_run.sh | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/setup_and_run.sh b/setup_and_run.sh
index e4b4e824a5..af74eca5ad 100755
--- a/setup_and_run.sh
+++ b/setup_and_run.sh
@@ -10,6 +10,9 @@
 echo "=== Step 1: Download dataset ==="
 # Run from the repo root (already cloned)
 cd /workspace/parameter-golf
 
+# HF token for faster downloads (avoids rate limiting)
+export HF_TOKEN="${HF_TOKEN:-REDACTED_REVOKE_LEAKED_TOKEN_AND_SET_VIA_ENV}"
+
 # Download full dataset (~18 min)
 python3 data/cached_challenge_fineweb.py --variant sp1024

From 183dfa3a190801fbf2b9e6ec88dd1862adea6259 Mon Sep 17 00:00:00 2001
From: Billy Endson
Date: Sat, 21 Mar 2026 03:08:06 +0700
Subject: [PATCH 06/32]
=?UTF-8?q?Record:=20MoS=20K=3D2=20R=3D64=20pilot=20?= =?UTF-8?q?=E2=80=94=20val=5Fbpb=3D1.3932=20(1xH100,=2010min)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First MoS pilot run. 1113 steps on 1xH100 SXM, 12.8MB artifact. Loss still dropping at wallclock cap. Co-Authored-By: Claude Opus 4.6 --- .../README.md | 46 + .../submission.json | 31 + .../train.log | 93 + .../train_gpt.py | 1567 +++++++++++++++++ 4 files changed, 1737 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/README.md create mode 100644 records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/submission.json create mode 100644 records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train.log create mode 100644 records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train_gpt.py diff --git a/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/README.md b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/README.md new file mode 100644 index 0000000000..8a62b6780d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/README.md @@ -0,0 +1,46 @@ +First pilot run of Mixture of Softmax (MoS) on 1x H100 SXM, 10-minute wallclock. 
+ +Configuration: +- Track: `non-record`, 1x H100 SXM, 10 min wallclock +- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` +- MoS: `USE_MOS=1 MOS_K=2 MOS_RANK=64` (low-rank factorization, ~99K extra params) +- Tied embeddings, seed=42 + +Command: +```bash +RUN_ID=mos_k2_r64_pilot \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 SEED=42 \ +USE_MOS=1 MOS_K=2 MOS_RANK=64 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py +``` + +Key metrics: +- Stopped at step 1113/20000 (wallclock cap) +- Pre-quant: `val_loss:2.3505 val_bpb:1.3921` +- Post-quant (int8+zlib): `val_loss:2.3523 val_bpb:1.3932` +- Quantization degradation: +0.0011 bpb (minimal) +- Model params: 17,159,240 +- Artifact: 12,764,492 bytes int8+zlib (12.8MB, 3.2MB under 16MB cap) +- Code: 63,345 bytes +- Total: 12,827,837 bytes +- Peak memory: 11,012 MiB allocated +- Step avg: 539ms/step on 1x H100 + +Training curve: +| Step | Train Loss | Val BPB | Time | +|------|-----------|---------|------| +| 0 | 6.93 | 4.11 | 0s | +| 100 | 3.27 | — | 54s | +| 500 | 2.58 | 1.52 | 271s | +| 1000 | 2.40 | 1.40 | 542s | +| 1113 | — | 1.39 | 600s | + +Notes: +- Loss still dropping at wallclock stop — model had more to learn +- No TTT/LoRA eval was run (only int8 roundtrip) +- No same-conditions baseline for direct comparison (8xH100 baseline: ~1.2244 bpb at 20K steps) +- 1x H100 = ~1/8 throughput → only 1113 steps vs ~20K on 8xH100 diff --git a/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/submission.json b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/submission.json new file mode 100644 index 0000000000..4112a7586f --- /dev/null +++ b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/submission.json @@ -0,0 +1,31 @@ +{ + "author": "billyendson", + "github_id": "User123331", + 
"name": "MoS K=2 Rank=64 Pilot (1xH100, 10min)", + "blurb": "First pilot of Mixture of Softmax (K=2, low-rank=64) on 1xH100 SXM for 10 minutes. Tests softmax bottleneck breaking with minimal parameter overhead (~99K params, 97KB). Artifact 12.8MB, well under 16MB cap. No TTT eval. Loss still dropping at wallclock stop.", + "date": "2026-03-21T19:48:40Z", + "track": "non-record-unlimited-compute-16mb", + "val_loss": 2.35234121, + "val_bpb": 1.39318897, + "pre_quant_val_loss": 2.3505, + "pre_quant_val_bpb": 1.3921, + "step_stop": 1113, + "wallclock_seconds": 600.423, + "bytes_total": 12827837, + "bytes_model_int8_zlib": 12764492, + "bytes_code": 63345, + "gpu": "1xH100_SXM", + "config": { + "USE_MOS": 1, + "MOS_K": 2, + "MOS_RANK": 64, + "VOCAB_SIZE": 1024, + "NUM_LAYERS": 9, + "MODEL_DIM": 512, + "NUM_HEADS": 8, + "NUM_KV_HEADS": 4, + "MLP_MULT": 2, + "SEED": 42, + "TRAIN_SEQ_LEN": 1024 + } +} diff --git a/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train.log b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train.log new file mode 100644 index 0000000000..4904cb52fb --- /dev/null +++ b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train.log @@ -0,0 +1,93 @@ +logs/mos_k2_r64_pilot.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17159240 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning: +Online softmax is disabled on 
the fly since Inductor decides to +split the reduction. Cut an issue to PyTorch if this is an +important use case and you want to speed it up with online +softmax. + + warnings.warn( +/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning: +Online softmax is disabled on the fly since Inductor decides to +split the reduction. Cut an issue to PyTorch if this is an +important use case and you want to speed it up with online +softmax. + + warnings.warn( +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning: +Online softmax is disabled on the fly since Inductor decides to +split the reduction. Cut an issue to PyTorch if this is an +important use case and you want to speed it up with online +softmax. + + warnings.warn( +/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning: +Online softmax is disabled on the fly since Inductor decides to +split the reduction. Cut an issue to PyTorch if this is an +important use case and you want to speed it up with online +softmax. 
+ + warnings.warn( +step:0/20000 val_loss:6.9314 val_bpb:4.1052 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9314 train_time:581ms step_avg:581.11ms +step:2/20000 train_loss:6.8515 train_time:1295ms step_avg:647.47ms +step:3/20000 train_loss:5.8655 train_time:1996ms step_avg:665.19ms +step:4/20000 train_loss:5.4250 train_time:2793ms step_avg:698.33ms +step:5/20000 train_loss:5.0728 train_time:3413ms step_avg:682.51ms +step:6/20000 train_loss:4.9797 train_time:4016ms step_avg:669.27ms +step:7/20000 train_loss:4.8555 train_time:4676ms step_avg:668.03ms +step:8/20000 train_loss:4.7612 train_time:5341ms step_avg:667.67ms +step:9/20000 train_loss:4.6900 train_time:5990ms step_avg:665.54ms +step:10/20000 train_loss:4.7029 train_time:6682ms step_avg:668.23ms +step:100/20000 train_loss:3.2746 train_time:54475ms step_avg:544.75ms +step:200/20000 train_loss:2.8511 train_time:108479ms step_avg:542.40ms +step:300/20000 train_loss:2.7046 train_time:162973ms step_avg:543.24ms +step:400/20000 train_loss:2.4804 train_time:217390ms step_avg:543.47ms +step:500/20000 train_loss:2.5755 train_time:271183ms step_avg:542.37ms +step:500/20000 val_loss:2.5703 val_bpb:1.5223 train_time:271193ms step_avg:542.39ms +step:600/20000 train_loss:2.5630 train_time:324786ms step_avg:541.31ms +step:700/20000 train_loss:2.5112 train_time:378359ms step_avg:540.51ms +step:800/20000 train_loss:2.3957 train_time:432963ms step_avg:541.20ms +step:900/20000 train_loss:2.4135 train_time:487589ms step_avg:541.77ms +step:1000/20000 train_loss:2.4031 train_time:542181ms step_avg:542.18ms +step:1000/20000 val_loss:2.3696 val_bpb:1.4034 train_time:542248ms step_avg:542.25ms +step:1100/20000 train_loss:2.3186 train_time:594115ms step_avg:540.10ms +step:1113/20000 val_loss:2.3505 val_bpb:1.3921 train_time:600423ms step_avg:539.46ms +stopping_early: wallclock_cap train_time:600423ms step:1113/20000 +peak memory allocated: 11012 MiB reserved: 11320 MiB +Serialized model: 67623386 bytes +Code size: 63345 
bytes +Total submission size: 67686731 bytes +Serialized model int8+zlib: 12764492 bytes (payload:17377568 raw_torch:17423635 payload_ratio:3.89x) +Total submission size int8+zlib: 12827837 bytes +final_int8_zlib_roundtrip val_loss:2.3523 val_bpb:1.3932 eval_time:11887ms +final_int8_zlib_roundtrip_exact val_loss:2.35234121 val_bpb:1.39318897 diff --git a/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train_gpt.py b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train_gpt.py new file mode 100644 index 0000000000..40db9e1f6d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train_gpt.py @@ -0,0 +1,1567 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Test-time training (LoRA) hyperparameters. 
+ ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + + # Mixture of Softmax (MoS) output layer - breaks softmax bottleneck. + # At vocab=1024, dim=512, standard softmax has rank ≤ 513 (binding constraint). + # MoS with K=2 lifts this to rank ≤ 1026, enabling richer output distributions. + use_mos = bool(int(os.environ.get("USE_MOS", "0"))) + mos_k = int(os.environ.get("MOS_K", 2)) + mos_rank = int(os.environ.get("MOS_RANK", 64)) # 0 = full-rank, >0 = low-rank factorization + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used for tokenizer-agnostic BPB scoring.

    Returns three aligned tensors of length max(sp.vocab_size(), vocab_size):
      - base_bytes: UTF-8 byte length of each piece (1 for raw byte tokens)
      - has_leading_space: True for pieces carrying the SentencePiece "▁" marker
      - is_boundary_token: True for control/unknown/unused ids
    """
    sp_vocab_size = int(sp.vocab_size())
    # The model vocab may be padded beyond the tokenizer's; size tables to cover both.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        # Control/unknown/unused ids contribute zero bytes and stay marked as boundaries.
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            # "▁" prefix: the leading-space byte is charged only when the previous
            # token is not a boundary token (see the byte accounting in eval_val).
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern`.

    Truncates to a whole number of `seq_len` sequences plus one extra token so
    (x, y) pairs can be built by shifting; raises if no shards match or the
    split is shorter than one sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Run full validation and return (val_loss, val_bpb)."""
    # Validation computes two metrics:
    #   - val_loss: token cross-entropy (natural log)
    #   - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Partition whole sequences across ranks as evenly as integer division allows.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # Accumulate in float64 so many-batch sums do not drift in low precision.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # "+1" so the shifted target of the final position exists.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                # Model returns mean loss for the batch (see GPT.forward without lora).
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            # Re-weight the mean loss by token count so partial final batches are exact.
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            # Byte accounting: base UTF-8 bytes plus one leading-space byte when the
            # previous token is not a boundary token.
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    # Combine per-rank partial sums before computing the global ratios.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bpb = (nats/token / ln 2) * (tokens / bytes) = total bits / total bytes.
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, 
dtype=torch.float32))
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor:
        # Per-channel learned blend of the running residual x with the embedding stream x0.
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        n = self.attn_norm(x)
        # Optional per-batch LoRA deltas for Q and V, supplied at eval time by GPT.forward.
        qd = q_delta_fn(n) if q_delta_fn is not None else None
        vd = v_delta_fn(n) if v_delta_fn is not None else None
        attn_out = self.attn(n, qd, vd)
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class MixtureOfSoftmax(nn.Module):
    """Mixture of Softmax output layer for breaking the softmax bottleneck.

    At vocab=1024, dim=512, the standard softmax has rank ≤ 513.
    MoS with K=2 lifts this to rank ≤ 1026, enabling richer output distributions.

    When mos_rank > 0, uses low-rank factorization to save parameters:
    instead of dim -> K*dim projection, uses dim -> rank -> K*dim.

    Paper: Yang et al. (2018), "Breaking the Softmax Bottleneck", ICLR 2018.
    """

    def __init__(self, model_dim: int, vocab_size: int, n_mixtures: int = 2, rank: int = 0):
        super().__init__()
        self.n_mixtures = n_mixtures
        self.model_dim = model_dim
        self.vocab_size = vocab_size
        self.rank = rank

        if rank > 0:
            # Low-rank factorization: dim -> rank -> K*dim
            self.proj_down = CastedLinear(model_dim, rank, bias=False)
            self.proj_up = CastedLinear(rank, n_mixtures * model_dim, bias=False)
            nn.init.normal_(self.proj_down.weight, mean=0.0, std=0.02)
            nn.init.normal_(self.proj_up.weight, mean=0.0, std=0.02)
        else:
            # Full-rank: dim -> K*dim
            self.projections = CastedLinear(model_dim, n_mixtures * model_dim, bias=False)
            nn.init.normal_(self.projections.weight, mean=0.0, std=0.02)

        # Mixing weight predictor
        self.gate = CastedLinear(model_dim, n_mixtures, bias=False)
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

    def forward(self, hidden: Tensor, weight_matrix: Tensor) -> Tensor:
        """Compute mixed softmax distribution.

        Args:
            hidden: (bsz, seq_len, dim) - final hidden states
            weight_matrix: (vocab_size, dim) - tied embedding weights

        Returns:
            log_probs: (bsz, seq_len, vocab_size) - mixed log probabilities
        """
        bsz, seq_len, dim = hidden.shape
        K = self.n_mixtures

        # Compute mixing weights: (bsz, seq, K)
        pi = F.softmax(self.gate(hidden), dim=-1)

        # Project to K different spaces: (bsz, seq, K * dim) -> (bsz, seq, K, dim)
        if self.rank > 0:
            projected = self.proj_up(self.proj_down(hidden)).view(bsz, seq_len, K, dim)
        else:
            projected = self.projections(hidden).view(bsz, seq_len, K, dim)

        # Compute K different logit vectors: (bsz, seq, K, vocab)
        logits = torch.einsum('bskd,vd->bskv', projected, weight_matrix)

        # Mix softmax distributions using log-space for numerical stability.
        # The 1e-10 floor keeps log(pi) finite when a mixture weight underflows to 0.
        log_probs = F.log_softmax(logits, dim=-1)  # (bsz, seq, K, vocab)
        log_pi = torch.log(pi.unsqueeze(-1) + 1e-10)  # (bsz, seq, K, 1)
        mixed_log_probs = torch.logsumexp(log_probs + log_pi, dim=2)  # (bsz, seq, vocab)

        return mixed_log_probs


class GPT(nn.Module):
    """Encoder/decoder-skip GPT with tied embeddings and an optional MoS head."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        use_mos: bool = False,
        mos_k: int = 2,
        mos_rank: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.use_mos = use_mos
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # First half of blocks are "encoder" layers whose outputs feed weighted
        # skip connections into the second ("decoder") half.
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.blocks = nn.ModuleList(
            [
                Block(
                    model_dim,
                    num_heads,
                    num_kv_heads,
                    mlp_mult,
                    rope_base,
                    qk_gain_init,
                )
                for i in range(num_layers)
            ]
        )
        self.final_norm = RMSNorm()
        # MoS output layer (optional) - breaks softmax bottleneck
        if use_mos:
            self.mos = MixtureOfSoftmax(model_dim, vocab_size, n_mixtures=mos_k, rank=mos_rank)
        else:
            self.mos = None
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embeddings get a small-std init; any module flagged _zero_init
        # (output projections, untied lm_head) starts at exactly zero.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor:
        """Return training loss: scalar mean without `lora`, per-token (bsz, seq) with it."""
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x
        skips: list[Tensor] = []

        # First half stores skips; second half reuses them in reverse order.
        for i in range(self.num_encoder_layers):
            qd = lora.q_loras[i] if lora else None
            vd = lora.v_loras[i] if lora else None
            x = self.blocks[i](x, x0, qd, vd)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            bi = self.num_encoder_layers + i
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            qd = lora.q_loras[bi] if lora else None
            vd = lora.v_loras[bi] if lora else None
            x = self.blocks[bi](x, x0, qd, vd)
        x = self.final_norm(x)
        # Output layer
        # NOTE(review): MoS requires tied embeddings here; with use_mos=1 and
        # TIE_EMBEDDINGS=0 this condition silently falls through to the plain
        # lm_head path and self.mos is never used — confirm this is intended.
        # Also note the MoS path deliberately skips the logit softcap applied below.
        if self.mos is not None and self.tie_embeddings:
            # MoS: returns log-probs (already log-softmaxed), use NLL loss directly
            log_probs = self.mos(x, self.tok_emb.weight)
            if lora:
                # LoRA correction breaks normalization; re-normalize via log_softmax
                log_probs = F.log_softmax(log_probs + lora.lm_head_lora(x), dim=-1)
                bsz, sl, V = log_probs.shape
                return F.nll_loss(
                    log_probs.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl)
            return F.nll_loss(log_probs.float().reshape(-1, log_probs.size(-1)), target_ids.reshape(-1), reduction="mean")
        elif self.tie_embeddings:
            logits = F.linear(x, self.tok_emb.weight)
        else:
            logits = self.lm_head(x)
        logits = logits + (lora.lm_head_lora(x) if lora else 0)
        # Soft cap: tanh squashes logits into (-cap, cap) to stabilize training.
        logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
        if lora:
            bsz, sl, V = logits.shape
            return F.cross_entropy(
                logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl)
        return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean")


# -----------------------------
# TEST-TIME TRAINING (LoRA)
# -----------------------------
#
# At evaluation time, we adapt per-document low-rank adapters on the validation data.
# Each document gets its own adapter, so there is no inter-document dependency.
+ +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element. + Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) # kaiming-uniform + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: # Fresh state. 
+ continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document, identified by BOS boundaries. + + If include_next_bos is True, include next document's BOS (to match continuous-stream + eval token count exactly). + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + """Add one doc-chunk's contribution to the running BPB accumulators.""" + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = 
base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" + # Load validation tokens and find document boundaries + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + # Each rank takes a contiguous slice of documents + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + 
chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] # (chunk_offset, chunk_len) per doc + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + # Forward pass (keep grad graph alive only when we need to train) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + # Score: accumulate loss and byte counts for BPB (before training on chunk) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + # Train: one Adam step on the LoRA params using this chunk's loss + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, 
op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", 
encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + use_mos=args.use_mos, + mos_k=args.mos_k, + mos_rank=args.mos_rank, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + # MoS parameters: 2D projection weights go to Muon, gate goes to scalar optimizer + if base_model.mos is not None: + if base_model.mos.rank > 0: + matrix_params.append(base_model.mos.proj_down.weight) + matrix_params.append(base_model.mos.proj_up.weight) + else: + matrix_params.append(base_model.mos.projections.weight) + scalar_params.append(base_model.mos.gate.weight) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + 
momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if 
max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | 
None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * 
scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the 
competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri Mar 20 19:48:46 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 27C P0 96W / 700W | 1185MiB / 81559MiB | 10% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 1166 C /usr/local/bin/python 1176MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17159240 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9314 val_bpb:4.1052 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9314 train_time:581ms step_avg:581.11ms +step:2/20000 
train_loss:6.8515 train_time:1295ms step_avg:647.47ms +step:3/20000 train_loss:5.8655 train_time:1996ms step_avg:665.19ms +step:4/20000 train_loss:5.4250 train_time:2793ms step_avg:698.33ms +step:5/20000 train_loss:5.0728 train_time:3413ms step_avg:682.51ms +step:6/20000 train_loss:4.9797 train_time:4016ms step_avg:669.27ms +step:7/20000 train_loss:4.8555 train_time:4676ms step_avg:668.03ms +step:8/20000 train_loss:4.7612 train_time:5341ms step_avg:667.67ms +step:9/20000 train_loss:4.6900 train_time:5990ms step_avg:665.54ms +step:10/20000 train_loss:4.7029 train_time:6682ms step_avg:668.23ms +step:100/20000 train_loss:3.2746 train_time:54475ms step_avg:544.75ms +step:200/20000 train_loss:2.8511 train_time:108479ms step_avg:542.40ms +step:300/20000 train_loss:2.7046 train_time:162973ms step_avg:543.24ms +step:400/20000 train_loss:2.4804 train_time:217390ms step_avg:543.47ms +step:500/20000 train_loss:2.5755 train_time:271183ms step_avg:542.37ms +step:500/20000 val_loss:2.5703 val_bpb:1.5223 train_time:271193ms step_avg:542.39ms +step:600/20000 train_loss:2.5630 train_time:324786ms step_avg:541.31ms +step:700/20000 train_loss:2.5112 train_time:378359ms step_avg:540.51ms +step:800/20000 train_loss:2.3957 train_time:432963ms step_avg:541.20ms +step:900/20000 train_loss:2.4135 train_time:487589ms step_avg:541.77ms +step:1000/20000 train_loss:2.4031 train_time:542181ms step_avg:542.18ms +step:1000/20000 val_loss:2.3696 val_bpb:1.4034 train_time:542248ms step_avg:542.25ms +step:1100/20000 train_loss:2.3186 train_time:594115ms step_avg:540.10ms +step:1113/20000 val_loss:2.3505 val_bpb:1.3921 train_time:600423ms step_avg:539.46ms +stopping_early: wallclock_cap train_time:600423ms step:1113/20000 +peak memory allocated: 11012 MiB reserved: 11320 MiB +Serialized model: 67623386 bytes +Code size: 63345 bytes +Total submission size: 67686731 bytes +Serialized model int8+zlib: 12764492 bytes (payload:17377568 raw_torch:17423635 payload_ratio:3.89x) +Total submission size 
int8+zlib: 12827837 bytes +final_int8_zlib_roundtrip val_loss:2.3523 val_bpb:1.3932 eval_time:11887ms +final_int8_zlib_roundtrip_exact val_loss:2.35234121 val_bpb:1.39318897 From 2440ea856c84fc5f9bd422e968a6d800c887016b Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sat, 21 Mar 2026 03:19:52 +0700 Subject: [PATCH 07/32] Add 1-hour MoS validation script (targeting PR#111 baseline) 1hr run with MoS K=2 R=64 + WARMDOWN_ITERS=100 on 1xH100. Target: beat vanilla baseline val_bpb=1.2540 from PR#111. Co-Authored-By: Claude Opus 4.6 --- setup_and_run_1h.sh | 57 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 setup_and_run_1h.sh diff --git a/setup_and_run_1h.sh b/setup_and_run_1h.sh new file mode 100644 index 0000000000..bf422d8d6a --- /dev/null +++ b/setup_and_run_1h.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# === Parameter Golf: MoS 1-Hour Validation on 1x H100 === +# Paste this into your RunPod terminal. +# Total time: ~18 min download + 60 min MoS run = ~78 min +# Target to beat: val_bpb 1.2540 (PR#111 vanilla baseline, 1hr/1xH100) + +set -e + +echo "=== Step 1: Download dataset ===" +cd /workspace/parameter-golf + +# HF token for faster downloads +export HF_TOKEN="${HF_TOKEN:-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}" + +# Download full dataset (~18 min) +python3 data/cached_challenge_fineweb.py --variant sp1024 + +# Verify dataset +TRAIN_COUNT=$(ls data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) +VAL_COUNT=$(ls data/datasets/fineweb10B_sp1024/fineweb_val_*.bin 2>/dev/null | wc -l) +echo "Train shards: $TRAIN_COUNT Val shards: $VAL_COUNT" +if [ "$TRAIN_COUNT" -lt 1 ]; then + echo "ERROR: No training shards found. Dataset download failed." 
+ exit 1 +fi + +echo "" +echo "=== Step 2: Run MoS K=2 R=64 (1 HOUR, 1x H100) ===" +echo "Start time: $(date)" +echo "Target: beat PR#111 baseline val_bpb=1.2540" +echo "" + +RUN_ID=mos_k2_r64_1h \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +USE_MOS=1 \ +MOS_K=2 \ +MOS_RANK=64 \ +WARMDOWN_ITERS=100 \ +MAX_WALLCLOCK_SECONDS=3600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | tee /workspace/mos_1h_log.txt + +echo "" +echo "=== RESULTS ===" +echo "" +grep -E 'val_bpb|val_loss|bytes|param|model_params|stopping' /workspace/mos_1h_log.txt | tail -20 +echo "" +echo "=== COMPARISON ===" +echo "Target (PR#111 vanilla 1hr): val_bpb=1.2540" +echo "Our 10-min MoS pilot: val_bpb=1.3932" +echo "PR#111 10-min baseline: val_bpb=1.3486" +echo "" +echo "Done at: $(date)" From 0773ee01a583093fcacd3cf027770a1259165f52 Mon Sep 17 00:00:00 2001 From: User123331 Date: Sat, 21 Mar 2026 04:15:51 +0700 Subject: [PATCH 08/32] Make 1h script survive terminal disconnects via nohup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Training now runs in background — safe to close terminal. Monitor with: tail -f /workspace/mos_1h_log.txt Co-Authored-By: Claude Opus 4.6 --- setup_and_run_1h.sh | 48 +++++++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/setup_and_run_1h.sh b/setup_and_run_1h.sh index bf422d8d6a..1efe040661 100644 --- a/setup_and_run_1h.sh +++ b/setup_and_run_1h.sh @@ -1,8 +1,8 @@ #!/bin/bash # === Parameter Golf: MoS 1-Hour Validation on 1x H100 === -# Paste this into your RunPod terminal. -# Total time: ~18 min download + 60 min MoS run = ~78 min -# Target to beat: val_bpb 1.2540 (PR#111 vanilla baseline, 1hr/1xH100) +# Usage: bash setup_and_run_1h.sh +# The script runs training inside nohup so it survives terminal disconnects. 
+# Log is written to /workspace/mos_1h_log.txt — check with: tail -f /workspace/mos_1h_log.txt set -e @@ -12,7 +12,7 @@ cd /workspace/parameter-golf # HF token for faster downloads export HF_TOKEN="${HF_TOKEN:-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}" -# Download full dataset (~18 min) +# Download full dataset (~18 min, skips if already present) python3 data/cached_challenge_fineweb.py --variant sp1024 # Verify dataset @@ -27,9 +27,14 @@ fi echo "" echo "=== Step 2: Run MoS K=2 R=64 (1 HOUR, 1x H100) ===" echo "Start time: $(date)" -echo "Target: beat PR#111 baseline val_bpb=1.2540" +echo "" +echo "Training will run in the background via nohup." +echo "Monitor with: tail -f /workspace/mos_1h_log.txt" +echo "Check GPU with: nvidia-smi" +echo "Safe to close terminal — training will continue." echo "" +nohup bash -c ' RUN_ID=mos_k2_r64_1h \ DATA_PATH=./data/datasets/fineweb10B_sp1024 \ TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ @@ -42,16 +47,25 @@ WARMDOWN_ITERS=100 \ MAX_WALLCLOCK_SECONDS=3600 \ VAL_LOSS_EVERY=500 \ TRAIN_LOG_EVERY=100 \ -torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | tee /workspace/mos_1h_log.txt +torchrun --standalone --nproc_per_node=1 train_gpt.py +' > /workspace/mos_1h_log.txt 2>&1 & -echo "" -echo "=== RESULTS ===" -echo "" -grep -E 'val_bpb|val_loss|bytes|param|model_params|stopping' /workspace/mos_1h_log.txt | tail -20 -echo "" -echo "=== COMPARISON ===" -echo "Target (PR#111 vanilla 1hr): val_bpb=1.2540" -echo "Our 10-min MoS pilot: val_bpb=1.3932" -echo "PR#111 10-min baseline: val_bpb=1.3486" -echo "" -echo "Done at: $(date)" +TRAIN_PID=$! +echo "Training PID: $TRAIN_PID" +echo "PID saved to /workspace/train.pid" +echo "$TRAIN_PID" > /workspace/train.pid + +# Wait a few seconds and confirm it started +sleep 5 +if kill -0 $TRAIN_PID 2>/dev/null; then + echo "Training is running. You can safely close this terminal." 
+ echo "" + echo "=== Quick commands ===" + echo " Monitor: tail -f /workspace/mos_1h_log.txt" + echo " Status: nvidia-smi" + echo " Kill: kill \$(cat /workspace/train.pid)" +else + echo "ERROR: Training process died. Check /workspace/mos_1h_log.txt" + tail -20 /workspace/mos_1h_log.txt + exit 1 +fi From 69d4a7ddb9803b627343d0e03c499e05e357871c Mon Sep 17 00:00:00 2001 From: User123331 Date: Sat, 21 Mar 2026 04:18:08 +0700 Subject: [PATCH 09/32] Add vanilla baseline 10-min script for 1xH100 comparison Co-Authored-By: Claude Opus 4.6 --- run_baseline_10min.sh | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 run_baseline_10min.sh diff --git a/run_baseline_10min.sh b/run_baseline_10min.sh new file mode 100644 index 0000000000..d81a4526e2 --- /dev/null +++ b/run_baseline_10min.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Vanilla baseline, 10 min, 1x H100. Survives terminal disconnect. +# Monitor: tail -f /workspace/baseline_10min_log.txt + +cd /workspace/parameter-golf + +nohup bash -c ' +RUN_ID=baseline_10min \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py +' > /workspace/baseline_10min_log.txt 2>&1 & + +echo "PID: $!" 
+echo "Monitor: tail -f /workspace/baseline_10min_log.txt" From 0a32f9d8795362d006a48d6b2e4411ffcd8b2061 Mon Sep 17 00:00:00 2001 From: User123331 Date: Sat, 21 Mar 2026 05:02:12 +0700 Subject: [PATCH 10/32] Add research artifacts: technique encyclopedia, combination matrix, speed optimizations, SOTA plan MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - techniques_encyclopedia.md: 39 techniques catalog with bpb impacts and PR references - combination_matrix.md: Compatibility matrix (++/+/~/−) with stacking recommendations - speed_optimizations.md: Triton/FA3/fused kernels research for throughput gains - PLAN_beat_SOTA.md: Phase-by-phase implementation plan targeting <1.13 bpb MoS rejected after experiments showed +0.057 bpb worse than baseline. Co-Authored-By: Claude Opus 4.6 --- Graphs/PLAN_beat_SOTA.md | 156 ++++++++++++++++++++++++++++++++++ Graphs/speed_optimizations.md | 99 +++++++++++++++++++++ 2 files changed, 255 insertions(+) create mode 100644 Graphs/PLAN_beat_SOTA.md create mode 100644 Graphs/speed_optimizations.md diff --git a/Graphs/PLAN_beat_SOTA.md b/Graphs/PLAN_beat_SOTA.md new file mode 100644 index 0000000000..b6f7bb8d18 --- /dev/null +++ b/Graphs/PLAN_beat_SOTA.md @@ -0,0 +1,156 @@ +# Plan: Beat SOTA (1.1428 bpb) + +**Date**: 2026-03-21 +**Current SOTA**: 1.1428 (thwu1, PR #180) +**Emerging**: 1.1303 (PR #254), 1.1307 (PR #265) — not yet on leaderboard +**Our target**: < 1.13 bpb + +--- + +## Strategy: Combine proven techniques nobody has stacked together yet + +The key insight from analyzing all PRs: **no single submission combines ALL the best techniques**. Each top entry uses a subset. We stack them all. 
+ +--- + +## The Stack + +### Layer 1: Base Architecture (from thwu1 #180) +- 10-11 layers, dim=512, 8 heads, 4 KV heads (GQA) +- MLP 3x (hidden=1536), ReLU-squared +- U-Net skip connections +- Tied embeddings (FP16 passthrough) +- Logit softcap=30 + +### Layer 2: Quantization (from thwu1 #180) +- Int5 for MLP weights (saves ~1.86MB for extra layer/features) +- Int6 for attention weights +- zstd-22 compression +- 3% magnitude pruning post-training (better compression) +- WD=0.04 for quantization robustness + +### Layer 3: Input Augmentation (from thwu1 #180 + #265) +- BigramHash(10240) buckets, dim=128, projected to 512 +- SmearGate (proven compatible, +0.005-0.008) + +### Layer 4: Training Optimization (best of all PRs) +- Muon: lr=0.02, WD=0.04, momentum warmup 0.92→0.99 over 1500 steps (from #265) +- SWA: start_frac=0.4, every=50 steps (from thwu1) +- OrthoInit + muP scaling +- Warmdown=3000, warmup=20, grad_clip=0.3 +- Seq2048, batch=524K tokens (from #236 — more gradient updates) + +### Layer 5: Speed (from #265 + modded-nanogpt) +- FlashAttention 3 (Hopper native) — ~5% faster steps +- Fused Linear+ReLU^2 Triton kernel — ~10% MLP speedup +- torch.compile mode="max-autotune" + +### Layer 6: Eval-Time (from #265 + #267) +- Sliding window eval (stride=64) +- Partial XSA on last 3 layers (from #265, +0.002 bpb, only 2ms/step) +- Causal TTT: SGD on val chunks after scoring (from #267, +0.003 bpb) + +### Layer 7: Free Training Signal +- MTP auxiliary head (predict t+2, t+3) — discarded at save, zero artifact cost +- From PR #88 — provides gradient enrichment during training + +--- + +## Expected Impact Breakdown + +| Technique | bpb gain over baseline | Source | +|-----------|----------------------|--------| +| Int5/6 + MLP3x + 10L | ~0.08 | thwu1 baseline | +| BigramHash(10240) | ~0.01 | thwu1 | +| SmearGate | ~0.006 | PR #162 | +| SWA | ~0.005 | thwu1 | +| OrthoInit + muP | ~0.004 | PR #198 | +| Sliding Window | ~0.03 | All top PRs | +| Seq2048 | ~0.015 | PR #198 
| +| Smaller batch (524K) | ~0.003 | PR #236 | +| FA3 + fused kernels (more steps) | ~0.005 | PR #265 | +| Partial XSA (last 3 layers) | ~0.002 | PR #265 | +| Causal TTT | ~0.003 | PR #267 | +| MTP auxiliary | ~0.002 | PR #88 | +| **Total from 1.2244 baseline** | **~0.165** | | +| **Projected bpb** | **~1.06-1.10** | | + +Conservative estimate: **1.10-1.12 bpb** (not everything stacks perfectly). + +--- + +## Implementation Phases + +### Phase 1: Fork SOTA code (~2 hours) +- Take thwu1's train_gpt.py from PR #180 as base +- Verify it reproduces 1.1428 on 8xH100 (10 min run, ~$3) +- This becomes our baseline to improve upon + +### Phase 2: Add proven extras (~3 hours) +- Add SmearGate (if not already in thwu1's code) +- Add Muon momentum warmup (0.92→0.99) +- Switch to batch=524K +- Add FlashAttention 3 +- Test on 1xH100 for quick validation + +### Phase 3: Add novel techniques (~4 hours) +- Implement Partial XSA on last 3 layers (from PR #265) +- Add MTP auxiliary head (from PR #88) +- Add fused Triton kernels (Linear+ReLU^2, softcapped CE) +- Test on 1xH100 + +### Phase 4: Eval-time optimization (~2 hours) +- Implement Causal TTT (SGD, 3 epochs per chunk) +- Tune TTT hyperparameters (lr, momentum, epochs) +- Test on 1xH100 + +### Phase 5: Record attempt (~$20) +- Full run on 8xH100, 10 min +- Submit to record track +- If < 1.13 → PR to openai/parameter-golf + +--- + +## Compute Budget + +| Phase | Hardware | Time | Cost | +|-------|----------|------|------| +| Phase 1 | 8xH100 | 15 min | ~$5 | +| Phase 2 | 1xH100 | 30 min | ~$2 | +| Phase 3 | 1xH100 | 1 hour | ~$4 | +| Phase 4 | 1xH100 | 30 min | ~$2 | +| Phase 5 | 8xH100 | 15 min | ~$5 | +| Buffer | — | — | ~$5 | +| **Total** | | | **~$23** | + +--- + +## What Makes This Novel + +Nobody has combined ALL of these: +1. Int5/Int6 mixed quant + 10-11L (thwu1) +2. + Partial XSA (PR #265, brand new technique) +3. + MTP auxiliary training (PR #88, free signal) +4. + Causal TTT (PR #267) +5. 
+ FA3 + fused Triton kernels (modded-nanogpt) +6. + Optimized batch size (PR #236) + +Each top PR uses 3-4 of these. We use all 6+. + +--- + +## Risk Assessment + +| Risk | Mitigation | +|------|-----------| +| Techniques don't stack as expected | Phase-by-phase testing on 1xH100 | +| XSA + TTT conflict | Test independently first | +| Int5 fragile with new techniques | Fall back to Int6 if quant degrades | +| Compute budget overrun | 1xH100 validation before 8xH100 record | +| FA3 install issues on RunPod | FA3 may already be in the template; fall back to FA2 | + +--- + +## Immediate Next Step + +Pull thwu1's code from PR #180 and start Phase 1. diff --git a/Graphs/speed_optimizations.md b/Graphs/speed_optimizations.md new file mode 100644 index 0000000000..c58770bbec --- /dev/null +++ b/Graphs/speed_optimizations.md @@ -0,0 +1,99 @@ +# Speed Optimizations: Triton Kernels & Libraries + +**Goal**: More training steps in the same wallclock = better bpb + +--- + +## Priority 1: FlashAttention 3 (~5% step time reduction) + +**What**: H100-optimized attention using Hopper async TMA + warp specialization +**Speedup**: 1.5-2x over FA2 in attention forward, ~5% overall step time +**Integration**: Drop-in replacement +```python +from flash_attn_interface import flash_attn_func as flash_attn_3_func +``` +**Status**: Proven — PRs #198 and #164 use this. Only external library in top submissions. 
+**Install**: `pip install flash-attn --no-build-isolation` (from hopper branch) + +--- + +## Priority 2: Fused Linear+ReLU^2 Triton Kernel (~5-15% MLP speedup) + +**What**: Fuses CastedLinear + relu().square() into one Triton kernel +**Source**: modded-nanogpt `triton_kernels.FusedLinearReLUSquareFunction` +**Why it helps**: Eliminates intermediate tensor materialization in MLP (which is 3x expanded) +**Integration**: Copy Triton kernel, replace MLP forward pass +**Status**: Used in modded-nanogpt speedrun, not yet in any Parameter Golf PR + +--- + +## Priority 3: Fused Softcapped Cross-Entropy (~2-5% loss speedup) + +**What**: Fuses logit_softcap + cross_entropy into one Triton kernel +**Source**: modded-nanogpt `triton_kernels.FusedSoftcappedCrossEntropy` +**Why it helps**: Avoids materializing softcapped logits tensor +**Integration**: Copy Triton kernel, replace loss computation +**Note**: Only applies to non-MoS path (MoS uses nll_loss on log-probs) +**Status**: Used in modded-nanogpt speedrun, not yet in any Parameter Golf PR + +--- + +## Priority 4: torch.compile Tuning (0-5% overall) + +```python +# Current +torch.compile(model, dynamic=False, fullgraph=True) + +# Try +torch.compile(model, dynamic=False, fullgraph=True, mode="max-autotune") +``` + +Also set env var: +```bash +export PYTORCH_ALLOC_CONF="expandable_segments:True" +``` + +--- + +## Priority 5: Gradient Checkpointing (enables larger batch/seq) + +**What**: Recompute activations in backward pass instead of storing them +**Benefit**: 50-70% activation memory reduction, enables seq=2048 or larger batch on 1xH100 +**Cost**: ~20-33% more compute (5-10% wall-clock in practice) +**When to use**: If moving to seq=2048+ on 1xH100 + +--- + +## Priority 6: Custom Triton MoS Kernel (if MoS proves useful) + +**What**: Fuse log_softmax over K components + logsumexp mixture into one kernel +**Expected**: Reduce MoS overhead from ~5ms to ~2-3ms per step +**Effort**: ~50-100 lines of Triton, based on fused 
softmax tutorial
+**Note**: The bigger bottleneck is the K einsum matmuls, not the softmax

+---

+## NOT Worth It at Our Scale

+| Technique | Why Skip |
+|-----------|----------|
+| FP8 training (torchao) | dim=512 matrices too small, overhead > benefit |
+| Fused RMSNorm | torch.compile already fuses it |
+| Apex FusedAdam | Already using fused=True, marginal gain |
+| Liger FusedCE | Logit tensor tiny at vocab=1024 |
+| bitsandbytes 8-bit optimizer | Model too small to benefit |

+---

+## Impact Estimate

+| Optimization | Step Time Reduction | Extra Steps in 10min | bpb Impact |
+|-------------|--------------------|--------------------|------------|
+| FA3 | ~5% | +1000 steps | ~0.005 bpb |
+| Fused MLP | ~10% | +2000 steps | ~0.008 bpb |
+| Fused CE | ~3% | +600 steps | ~0.002 bpb |
+| max-autotune | ~2% | +400 steps | ~0.001 bpb |
+| **Combined** | **~20%** | **+4000 steps** | **~0.015 bpb** |

+At current ~500ms/step, 10 min gives ~1200 steps; a 20% reduction (400ms/step) gives ~1500 steps.
+On 8xH100 at ~27ms/step, 20% = ~22ms/step = ~27,300 steps vs ~22,200.
From fd4f106cb35866ffb8f191eb4f249685f7b2a245 Mon Sep 17 00:00:00 2001 From: User123331 Date: Sat, 21 Mar 2026 12:30:27 +0700 Subject: [PATCH 11/32] Integrate SOTA stack (thwu1's 1.1428 bpb) + custom tokenizer pipeline - Replace train_gpt.py with thwu1's #1 implementation: - 10 layers, 3x MLP, BigramHash(10240), SmearGate - Mixed int5/int6 quantization, SWA, sliding eval - zstd-22 compression, magnitude pruning - Add custom tokenizer training pipeline: - run_custom_tokenizer_pipeline.sh: all-in-one script - data/train_tokenizer.py: SentencePiece trainer - Add run scripts: - run_competitive.sh: SOTA stack with default tokenizer - run_competitive_custom_tok.sh: SOTA stack with custom tokenizer Co-Authored-By: Claude Opus 4.6 --- data/train_tokenizer.py | 655 ++++++++++++++++++++ run_competitive.sh | 92 +++ run_competitive_custom_tok.sh | 87 +++ run_custom_tokenizer_pipeline.sh | 149 +++++ train_gpt.py | 986 ++++++++++++------------------- 5 files changed, 1353 insertions(+), 616 deletions(-) create mode 100644 data/train_tokenizer.py create mode 100755 run_competitive.sh create mode 100755 run_competitive_custom_tok.sh create mode 100755 run_custom_tokenizer_pipeline.sh diff --git a/data/train_tokenizer.py b/data/train_tokenizer.py new file mode 100644 index 0000000000..44f4b0b415 --- /dev/null +++ b/data/train_tokenizer.py @@ -0,0 +1,655 @@ +#!/usr/bin/env python3 +""" +Tokenizer Trainer for Parameter Golf Competition + +Trains custom SentencePiece tokenizers on FineWeb data, evaluates quality, +and optionally exports binary shards compatible with train_gpt.py. + +Supports both BPE and Unigram model types. Research suggests Unigram often +outperforms BPE at small vocab sizes (512-4096) due to its top-down pruning +strategy selecting globally-useful tokens vs BPE's greedy bottom-up merges. 
+ +Usage: + python train_tokenizer.py --vocab-size 1024 --model-type bpe + python train_tokenizer.py --vocab-size 1024 --model-type unigram + python train_tokenizer.py --compare # Compare BPE vs Unigram across vocab sizes + python train_tokenizer.py --vocab-size 1024 --export-shards # Train + export .bin shards + python train_tokenizer.py --evaluate path/to/model.model # Evaluate existing tokenizer + +Integration: + The trained .model file plugs directly into train_gpt.py: + VOCAB_SIZE=1024 python train_gpt.py --tokenizer-path ./tokenizers/spm_bpe_1024.model +""" + +import argparse +import json +import math +import os +import sys +import time +from collections import Counter +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +FINEWEB_DOCS_PATH = Path(os.environ.get( + "DOCS_JSONL_PATH", + "./data/docs_selected.jsonl" # Default for RunPod; override for local use +)) +OUTPUT_DIR = Path("./tokenizers") + +# Binary shard format (must match train_gpt.py / download_hf_docs_and_tokenize.py) +DATAFILE_MAGIC = 20240520 +DATAFILE_VERSION = 1 +SHARD_SIZE = 10**8 # 100M tokens per shard +NUM_VAL_DOCS = 50_000 +APPEND_EOS = False + +# Training data size recommendations by vocab size +# Larger vocabs need more data to see enough merge candidates +TRAIN_DOCS_COUNT = { + 512: 50_000, + 1024: 100_000, + 2048: 200_000, + 4096: 500_000, +} + +# Sample sentences for manual inspection of tokenization quality +SAMPLE_SENTENCES = [ + "The quick brown fox jumps over the lazy dog.", + "Machine learning models require careful hyperparameter tuning.", + "In 2024, researchers published 3,847 papers on language models.", + "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)", + "The café served crème brûlée for €12.50 — absolutely délicieux!", + 
"HTTP/1.1 200 OK\nContent-Type: application/json\n{\"status\": \"success\"}", + "∀x ∈ ℝ, |sin(x)| ≤ 1", +] + + +# ============================================================================= +# DATA LOADING +# ============================================================================= + +def iter_docs_jsonl(docs_path: Path, max_docs: Optional[int] = None): + """Iterate over texts from FineWeb docs_selected.jsonl (streaming, low memory).""" + if not docs_path.exists(): + raise FileNotFoundError( + f"{docs_path} not found!\n" + "Download FineWeb data first:\n" + " cd ../parameter-golf && python data/cached_challenge_fineweb.py --variant sp1024" + ) + count = 0 + with open(docs_path) as f: + for line in f: + if max_docs is not None and count >= max_docs: + break + line = line.strip() + if line: + yield json.loads(line)["text"] + count += 1 + print(f"Loaded {count:,} documents from {docs_path}") + + +def iter_sentences_for_spm(docs_path: Path, max_docs: Optional[int] = None): + """Yield individual sentences for SentencePiece sentence_iterator. + + SentencePiece's sentence_iterator expects one sentence per yield. + Splitting on newlines gives cleaner training signal than full documents. + """ + for text in iter_docs_jsonl(docs_path, max_docs): + for line in text.split("\n"): + line = line.strip() + if line: + yield line + + +# ============================================================================= +# SENTENCEPIECE TRAINER +# ============================================================================= + +def train_sentencepiece( + docs_path: Path, + vocab_size: int, + model_type: str, + output_dir: Path, + max_docs: Optional[int] = None, + *, + character_coverage: float = 0.995, + max_sentencepiece_length: int = 16, + min_frequency: int = 2, + input_sentence_size: int = 0, +) -> Dict[str, Any]: + """Train a SentencePiece tokenizer (BPE or Unigram). 
+ + Args: + docs_path: Path to docs_selected.jsonl + vocab_size: Target vocabulary size + model_type: 'bpe' or 'unigram' + output_dir: Directory to save model files + max_docs: Max documents to use for training + character_coverage: Unicode character coverage (0.995 recommended for small vocabs) + max_sentencepiece_length: Max token length in chars (prevents overly long tokens) + min_frequency: Minimum frequency for a token to be kept + input_sentence_size: Max sentences to use (0 = all). Set to 5M+ for large corpora. + """ + try: + import sentencepiece as spm + except ImportError: + print("ERROR: sentencepiece not installed. Run: pip install sentencepiece") + return {} + + output_dir.mkdir(parents=True, exist_ok=True) + model_prefix = str(output_dir / f"spm_{model_type}_{vocab_size}") + + print(f"\n{'=' * 60}") + print(f"Training SentencePiece {model_type.upper()} (vocab={vocab_size})") + print(f"{'=' * 60}") + + # Build training kwargs — uses sentence_iterator (streaming) instead of + # writing a temp file, matching the official data pipeline. + kwargs: Dict[str, Any] = { + "sentence_iterator": iter_sentences_for_spm(docs_path, max_docs), + "model_prefix": model_prefix, + "model_type": model_type, + "vocab_size": vocab_size, + # --- Coverage & fallback --- + # 0.995 saves vocab slots vs 0.9995; rare Unicode falls back to bytes. + "character_coverage": character_coverage, + # Critical for small vocab: reserves 256 byte tokens so any byte sequence + # is representable (no output). Costs 256 vocab slots. + "byte_fallback": True, + # --- Splitting rules --- + "split_digits": True, # Each digit 0-9 is its own token + "split_by_unicode_script": True, # Prevent cross-script merges + "split_by_number": True, # Prevent number-letter merges + # --- Normalization --- + # nmt_nfkc collapses Unicode variants (fullwidth chars, ligatures, etc.) + # saving vocab slots. Must match at inference time. 
+ "normalization_rule_name": "nmt_nfkc", + # False matches the official data pipeline. Means "Hello" stays "Hello" + # (no leading ▁), but " Hello" gets ▁. + "add_dummy_prefix": False, + # --- Special token IDs (must match train_gpt.py) --- + "pad_id": 0, + "bos_id": 1, + "eos_id": 2, + "unk_id": 3, + # --- Vocab constraints --- + "hard_vocab_limit": False, + "max_sentencepiece_length": max_sentencepiece_length, + } + + if min_frequency > 0: + # SentencePiece: --min_frequency is only used in BPE mode (ignored for unigram) + kwargs["min_frequency"] = min_frequency + + if input_sentence_size > 0: + kwargs["input_sentence_size"] = input_sentence_size + kwargs["shuffle_input_sentence"] = True + + start_time = time.time() + spm.SentencePieceTrainer.train(**kwargs) + train_time = time.time() - start_time + + print(f"Training completed in {train_time:.1f}s") + + # Load and verify + sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model") + + # Count token types + n_byte = sum(1 for i in range(sp.vocab_size()) if sp.is_byte(i)) + n_control = sum(1 for i in range(sp.vocab_size()) if sp.is_control(i)) + n_unknown = sum(1 for i in range(sp.vocab_size()) if sp.is_unknown(i)) + n_learned = sp.vocab_size() - n_byte - n_control - n_unknown + + print(f"Actual vocab size: {sp.vocab_size()}") + print(f" Learned subword tokens: {n_learned}") + print(f" Byte fallback tokens: {n_byte}") + print(f" Control tokens: {n_control}") + print(f" Unknown tokens: {n_unknown}") + + return { + "method": f"sentencepiece_{model_type}", + "model_type": model_type, + "vocab_size": vocab_size, + "actual_vocab": sp.vocab_size(), + "learned_tokens": n_learned, + "byte_tokens": n_byte, + "train_time_sec": train_time, + "model_path": f"{model_prefix}.model", + "vocab_path": f"{model_prefix}.vocab", + } + + +# ============================================================================= +# TOKENIZER EVALUATION +# ============================================================================= + +def 
evaluate_tokenizer(model_path: str, docs_path: Path, n_eval_docs: int = 5000) -> Dict[str, Any]: + """Comprehensive evaluation of a trained SentencePiece tokenizer. + + Computes: + - Compression ratio (bytes per token) — higher is better + - Fertility (tokens per word) — lower is better + - Token length distribution + - Coverage analysis + - Sample tokenizations for manual inspection + """ + import sentencepiece as spm + + sp = spm.SentencePieceProcessor(model_file=model_path) + + print(f"\n{'=' * 60}") + print(f"Evaluating: {model_path}") + print(f"Vocab size: {sp.vocab_size()}, Eval docs: {n_eval_docs:,}") + print(f"{'=' * 60}") + + total_bytes = 0 + total_tokens = 0 + total_words = 0 + token_lengths: List[int] = [] # UTF-8 byte length of each token's text + token_freq: Counter = Counter() + + start = time.time() + for text in iter_docs_jsonl(docs_path, max_docs=n_eval_docs): + text_bytes = len(text.encode("utf-8")) + ids = sp.encode(text, out_type=int) + pieces = sp.encode(text, out_type=str) + words = text.split() + + total_bytes += text_bytes + total_tokens += len(ids) + total_words += len(words) + + for piece in pieces: + clean = piece.lstrip("▁") + token_lengths.append(len(clean.encode("utf-8"))) + + token_freq.update(ids) + + eval_time = time.time() - start + + # Core metrics + bytes_per_token = total_bytes / total_tokens if total_tokens else 0 + fertility = total_tokens / total_words if total_words else 0 + bits_per_byte_ceiling = math.log2(sp.vocab_size()) # Theoretical worst case + + # Token length distribution + lengths_arr = np.array(token_lengths) + pct_single_byte = np.sum(lengths_arr <= 1) / len(lengths_arr) * 100 + pct_multi_char = np.sum(lengths_arr >= 3) / len(lengths_arr) * 100 + + # Vocab utilization: how many unique tokens actually appear + unique_used = len(token_freq) + vocab_utilization = unique_used / sp.vocab_size() * 100 + + # Top tokens + top_20 = token_freq.most_common(20) + + print(f"\n--- Compression ---") + print(f" Bytes per token: 
{bytes_per_token:.2f} (higher = better)") + print(f" Tokens per word: {fertility:.2f} (lower = better)") + print(f" Bits/byte ceiling: {bits_per_byte_ceiling:.2f} (log2(vocab_size))") + print(f" Effective BPB: ~{bits_per_byte_ceiling / bytes_per_token:.2f} (ceiling / compression)") + + print(f"\n--- Token Length Distribution ---") + print(f" Mean token length: {lengths_arr.mean():.1f} bytes") + print(f" Median: {np.median(lengths_arr):.0f} bytes") + print(f" Single-byte tokens: {pct_single_byte:.1f}%") + print(f" Multi-char (≥3B): {pct_multi_char:.1f}%") + + print(f"\n--- Vocab Utilization ---") + print(f" Unique tokens used: {unique_used:,} / {sp.vocab_size()} ({vocab_utilization:.1f}%)") + + print(f"\n--- Top 20 Tokens ---") + for token_id, count in top_20: + piece = sp.id_to_piece(token_id) + print(f" {token_id:5d} | {piece:20s} | {count:,}") + + # Sample tokenizations + print(f"\n--- Sample Tokenizations ---") + for sentence in SAMPLE_SENTENCES: + pieces = sp.encode(sentence, out_type=str) + ids = sp.encode(sentence, out_type=int) + print(f"\n Input: {sentence[:80]}") + print(f" Tokens: {len(ids)}") + print(f" Pieces: {' | '.join(pieces[:30])}") + + print(f"\n Eval time: {eval_time:.1f}s") + + return { + "model_path": model_path, + "vocab_size": sp.vocab_size(), + "bytes_per_token": bytes_per_token, + "fertility": fertility, + "bits_per_byte_ceiling": bits_per_byte_ceiling, + "mean_token_length_bytes": float(lengths_arr.mean()), + "pct_single_byte_tokens": pct_single_byte, + "pct_multi_char_tokens": pct_multi_char, + "vocab_utilization_pct": vocab_utilization, + "eval_docs": n_eval_docs, + "total_tokens": total_tokens, + "total_bytes": total_bytes, + } + + +# ============================================================================= +# BINARY SHARD EXPORT (compatible with train_gpt.py) +# ============================================================================= + +def write_datafile(path: Path, toks: np.ndarray) -> None: + """Write a binary shard file 
matching train_gpt.py format."""
+    if len(toks) >= 2**31:
+        raise ValueError("token count too large")
+    header = np.zeros(256, dtype="<i4")
+    header[0] = DATAFILE_MAGIC
+    header[1] = DATAFILE_VERSION
+    header[2] = len(toks)
+    with open(path, "wb") as f:
+        f.write(header.tobytes())
+        f.write(toks.astype(np.uint16).tobytes())
+
+
+def export_binary_shards(
+    model_path: str,
+    docs_path: Path,
+    output_dir: Path,
+    num_val_docs: int = NUM_VAL_DOCS,
+    shard_size: int = SHARD_SIZE,
+) -> Dict[str, int]:
+    """Export tokenized binary shards for train_gpt.py.
+
+    Format: [256-int32 header][uint16 token stream]
+    Each document: [BOS_ID] [encoded tokens] (no EOS by default)
+    First `num_val_docs` go to val shards, rest to train shards.
+    """
+    import sentencepiece as spm
+
+    sp = spm.SentencePieceProcessor(model_file=model_path)
+    vocab_size = sp.vocab_size()
+
+    if vocab_size > 2**16:
+        raise ValueError(f"vocab_size={vocab_size} too large for uint16 shard storage")
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Clean stale shards
+    for pattern in ("fineweb_train_*.bin", "fineweb_val_*.bin"):
+        for stale in output_dir.glob(pattern):
+            stale.unlink()
+
+    stats = {k: 0 for k in [
+        "docs_total", "docs_val", "docs_train",
+        "files_total", "files_val", "files_train",
+        "tokens_total", "tokens_val", "tokens_train",
+    ]}
+
+    buf = np.empty((shard_size,), dtype=np.uint16)
+    fill = 0
+    split = "val"
+    shards = {"val": 0, "train": 0}
+
+    def flush():
+        nonlocal fill
+        if fill == 0:
+            return
+        path = output_dir / f"fineweb_{split}_{shards[split]:06d}.bin"
+        write_datafile(path, buf[:fill])
+        stats["files_total"] += 1
+        stats[f"files_{split}"] += 1
+        shards[split] += 1
+        fill = 0
+
+    bos_id = sp.bos_id()
+
+    print(f"\nExporting binary shards to {output_dir}/")
+    print(f"  Tokenizer: {model_path} (vocab={vocab_size})")
+    print(f"  Val docs: {num_val_docs:,}, Shard size: {shard_size:,} tokens")
+
+    for text in iter_docs_jsonl(docs_path):
+        doc_split = "val" if stats["docs_total"] < num_val_docs else "train"
+        if doc_split != split:
+            flush()
+            split = doc_split
+
+        encoded = np.asarray(sp.encode(text, out_type=int), dtype=np.int32)
+        toks = np.empty((encoded.size + 1 + int(APPEND_EOS),), dtype=np.int32)
+        toks[0] = bos_id
+        toks[1: 1 + encoded.size] = encoded
+        if APPEND_EOS:
+            toks[-1] = sp.eos_id()
+
+        if not ((0 <=
toks).all() and (toks < vocab_size).all()):
+            bad = int(toks[(toks < 0) | (toks >= vocab_size)][0])
+            raise ValueError(f"token id {bad} outside vocab_size={vocab_size}")
+        toks = toks.astype("<u2")
+
+        stats["docs_total"] += 1
+        stats[f"docs_{split}"] += 1
+        stats["tokens_total"] += int(toks.size)
+        stats[f"tokens_{split}"] += int(toks.size)
+
+        # Copy into the shard buffer, flushing each time a shard fills
+        pos = 0
+        while pos < toks.size:
+            take = min(shard_size - fill, toks.size - pos)
+            buf[fill: fill + take] = toks[pos: pos + take]
+            fill += take
+            pos += take
+            if fill == shard_size:
+                flush()
+
+    flush()
+
+    print(f"  Docs: {stats['docs_total']:,} ({stats['docs_val']:,} val / {stats['docs_train']:,} train)")
+    print(f"  Tokens: {stats['tokens_total']:,} in {stats['files_total']} shard files")
+    return stats
+
+
+def calc_embedding_budget(vocab_size: int, model_dim: int = 512) -> Dict[str, Any]:
+    """Calculate embedding size impact on the 16MB artifact budget."""
+    bytes_fp16 = vocab_size * model_dim * 2
+    artifact_budget = 16_000_000
+    baseline_vocab = 1024
+    baseline_bytes = baseline_vocab * model_dim * 2
+
+    return {
+        "vocab_size": vocab_size,
+        "model_dim": model_dim,
+        "embedding_bytes_fp16": bytes_fp16,
+        "embedding_mb_fp16": bytes_fp16 / 1_000_000,
+        "budget_remaining": artifact_budget - bytes_fp16,
+        "vs_baseline_1024": bytes_fp16 - baseline_bytes,
+        "pct_of_budget": bytes_fp16 / artifact_budget * 100,
+    }
+
+
+# =============================================================================
+# COMPARISON
+# =============================================================================
+
+def run_comparison(docs_path: Path, vocab_sizes: List[int], model_types: List[str],
+                   max_docs: Optional[int] = None, n_eval_docs: int = 5000):
+    """Train and evaluate all combinations, print comparison table."""
+    results = []
+
+    for vocab_size in vocab_sizes:
+        n_docs = max_docs or TRAIN_DOCS_COUNT.get(vocab_size, 100_000)
+        for model_type in model_types:
+            info = train_sentencepiece(docs_path, vocab_size, model_type, OUTPUT_DIR, max_docs=n_docs)
+            if not info:
+                continue
+            eval_result = evaluate_tokenizer(info["model_path"], docs_path, n_eval_docs=n_eval_docs)
+            budget = calc_embedding_budget(vocab_size)
+            results.append({**info, **eval_result, "budget": budget})
+
+    # Print comparison table
+    print(f"\n{'=' * 100}")
+    print("COMPARISON TABLE")
+    print(f"{'=' * 100}")
+    print(f"{'Method':<20} {'Vocab':>6} {'Actual':>6} {'Learned':>7} {'B/Tok':>6} "
+          f"{'Fert':>6} {'Emb MB':>7} {'Budget%':>8} {'Train(s)':>9}")
+    print("-" * 100)
+
+    for r in results:
+        print(f"{r['method']:<20} {r['vocab_size']:>6} {r['actual_vocab']:>6} "
+              f"{r['learned_tokens']:>7} 
{r['bytes_per_token']:>6.2f} " + f"{r['fertility']:>6.2f} {r['budget']['embedding_mb_fp16']:>7.2f} " + f"{r['budget']['pct_of_budget']:>7.1f}% {r['train_time_sec']:>8.1f}") + + # Save results + results_path = OUTPUT_DIR / "comparison_results.json" + with open(results_path, "w") as f: + json.dump(results, f, indent=2, default=str) + print(f"\nResults saved to {results_path}") + + # Recommendation + if results: + best = max(results, key=lambda r: r["bytes_per_token"]) + print(f"\nBest compression: {best['method']} vocab={best['vocab_size']} " + f"({best['bytes_per_token']:.2f} bytes/token)") + + return results + + +# ============================================================================= +# MAIN +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser( + description="Train custom SentencePiece tokenizers for Parameter Golf", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Train BPE tokenizer with vocab 1024 + python train_tokenizer.py --vocab-size 1024 --model-type bpe + + # Train Unigram tokenizer (often better for small vocabs) + python train_tokenizer.py --vocab-size 1024 --model-type unigram + + # Compare BPE vs Unigram across vocab sizes + python train_tokenizer.py --compare + + # Train + export binary shards for train_gpt.py + python train_tokenizer.py --vocab-size 1024 --model-type bpe --export-shards + + # Evaluate an existing tokenizer + python train_tokenizer.py --evaluate ./tokenizers/spm_bpe_1024.model + """, + ) + parser.add_argument("--vocab-size", type=int, default=1024, help="Target vocabulary size (default: 1024)") + parser.add_argument("--model-type", type=str, default="bpe", choices=["bpe", "unigram"], + help="SentencePiece model type (default: bpe)") + parser.add_argument("--compare", action="store_true", + help="Compare BPE vs Unigram across vocab sizes [512, 1024, 2048, 4096]") + parser.add_argument("--evaluate", type=str, 
default=None, metavar="MODEL_PATH", + help="Evaluate an existing .model file instead of training") + parser.add_argument("--export-shards", action="store_true", + help="Export binary shards after training (for train_gpt.py)") + parser.add_argument("--shard-output-dir", type=str, default=None, + help="Output directory for binary shards (default: alongside tokenizer)") + parser.add_argument("--docs-path", type=str, default=None, help="Path to docs_selected.jsonl") + parser.add_argument("--max-docs", type=int, default=None, help="Max docs for tokenizer training") + parser.add_argument("--eval-docs", type=int, default=5000, help="Docs for evaluation (default: 5000)") + parser.add_argument("--character-coverage", type=float, default=0.995, + help="Unicode character coverage (default: 0.995)") + parser.add_argument("--max-token-length", type=int, default=16, + help="Max SentencePiece token length in chars (default: 16)") + parser.add_argument("--input-sentence-size", type=int, default=0, + help="Max sentences for SPM training (0=all, set >0 for large corpora)") + args = parser.parse_args() + + docs_path = Path(args.docs_path) if args.docs_path else FINEWEB_DOCS_PATH + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + # --- Evaluate existing model --- + if args.evaluate: + if not Path(args.evaluate).exists(): + print(f"ERROR: {args.evaluate} not found!") + sys.exit(1) + evaluate_tokenizer(args.evaluate, docs_path, n_eval_docs=args.eval_docs) + return + + # --- Compare mode --- + if args.compare: + run_comparison( + docs_path, + vocab_sizes=[512, 1024, 2048, 4096], + model_types=["bpe", "unigram"], + max_docs=args.max_docs, + n_eval_docs=args.eval_docs, + ) + return + + # --- Train single tokenizer --- + max_docs = args.max_docs or TRAIN_DOCS_COUNT.get(args.vocab_size, 100_000) + + info = train_sentencepiece( + docs_path, + args.vocab_size, + args.model_type, + OUTPUT_DIR, + max_docs=max_docs, + character_coverage=args.character_coverage, + 
max_sentencepiece_length=args.max_token_length, + input_sentence_size=args.input_sentence_size, + ) + + if not info: + print("ERROR: Training failed!") + sys.exit(1) + + # Evaluate + eval_result = evaluate_tokenizer(info["model_path"], docs_path, n_eval_docs=args.eval_docs) + budget = calc_embedding_budget(args.vocab_size) + + # Print budget info + print(f"\n--- Embedding Budget Impact (d_model=512, FP16) ---") + print(f" Embedding size: {budget['embedding_mb_fp16']:.2f} MB") + print(f" % of 16MB budget: {budget['pct_of_budget']:.1f}%") + print(f" vs baseline (1024): {budget['vs_baseline_1024']:+,} bytes") + + # Save results + all_results = {**info, **eval_result, "budget": budget} + results_path = OUTPUT_DIR / f"result_{args.model_type}_{args.vocab_size}.json" + with open(results_path, "w") as f: + json.dump(all_results, f, indent=2, default=str) + print(f"\nResults saved to {results_path}") + + # Export shards if requested + if args.export_shards: + shard_dir = Path(args.shard_output_dir) if args.shard_output_dir else ( + OUTPUT_DIR / f"shards_{args.model_type}_{args.vocab_size}" + ) + export_binary_shards(info["model_path"], docs_path, shard_dir) + + # Print integration instructions + print(f"\n--- Integration ---") + print(f" To use with train_gpt.py:") + print(f" VOCAB_SIZE={info['actual_vocab']} python train_gpt.py \\") + print(f" --tokenizer-path {info['model_path']}") + if not args.export_shards: + print(f"\n To export binary shards:") + print(f" python train_tokenizer.py --vocab-size {args.vocab_size} " + f"--model-type {args.model_type} --export-shards") + + +if __name__ == "__main__": + main() diff --git a/run_competitive.sh b/run_competitive.sh new file mode 100755 index 0000000000..b36d9ba575 --- /dev/null +++ b/run_competitive.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# === Parameter Golf: Competitive Entry (SOTA stack) === +# Based on thwu1's #1 submission (1.1428 bpb) +# Techniques: 10L + BigramHash(10240) + SmearGate + Int5/Int6 mixed quant +# + 3x MLP + 
OrthoInit + SWA(0.4) + WD=0.04 + sliding eval +# + zstd-22 + magnitude pruning +# +# Requirements: 8x H100 SXM (or adjust WORLD_SIZE) +# Expected: ~1.14 bpb in 10 minutes +# +# Usage: +# # 8x H100 (full competitive run) +# bash run_competitive.sh +# +# # 1x H100 (pilot test) +# bash run_competitive.sh --pilot + +set -e + +PILOT=0 +if [[ "$1" == "--pilot" ]]; then + PILOT=1 +fi + +cd /workspace/parameter-golf + +# Install zstandard for better compression +pip install zstandard 2>/dev/null || true + +# Download dataset if not present +if [ ! -f data/datasets/fineweb10B_sp1024/fineweb_train_000000.bin ]; then + echo "=== Downloading dataset ===" + export HF_TOKEN="${HF_TOKEN:-}" + python3 data/cached_challenge_fineweb.py --variant sp1024 +fi + +echo "Train shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l)" +echo "Val shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_val_*.bin 2>/dev/null | wc -l)" + +if [ "$PILOT" -eq 1 ]; then + echo "" + echo "=== PILOT RUN: 1x H100, 10 min ===" + echo "Start: $(date)" + NPROC=1 +else + echo "" + echo "=== COMPETITIVE RUN: 8x H100, 10 min ===" + echo "Start: $(date)" + NPROC=$(nvidia-smi --list-gpus | wc -l) + echo "Detected GPUs: $NPROC" +fi + +RUN_ID="competitive_$(date +%Y%m%d_%H%M%S)" \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +NUM_LAYERS=10 \ +MODEL_DIM=512 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_MULT=3.0 \ +TIE_EMBEDDINGS=1 \ +TRAIN_SEQ_LEN=2048 \ +TRAIN_BATCH_TOKENS=786432 \ +WARMDOWN_ITERS=3000 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +WEIGHT_DECAY=0.04 \ +MATRIX_LR=0.02 \ +SCALAR_LR=0.02 \ +TIED_EMBED_LR=0.03 \ +MUON_MOMENTUM=0.99 \ +MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 \ +GRAD_CLIP_NORM=0.3 \ +BIGRAM_VOCAB_SIZE=10240 \ +BIGRAM_DIM=128 \ +SWA_ENABLED=1 \ +SWA_START_FRAC=0.4 \ +SWA_EVERY=50 \ +EVAL_STRIDE=64 \ +EVAL_BATCH_SEQS=32 
\ +torchrun --standalone --nproc_per_node=$NPROC train_gpt.py 2>&1 | tee /workspace/competitive_log.txt + +echo "" +echo "=== RESULTS ===" +grep -E 'val_bpb|final_int8|submission|model_params|swa:' /workspace/competitive_log.txt | tail -20 +echo "" +echo "Target: val_bpb < 1.1428 (current SOTA)" +echo "Done: $(date)" diff --git a/run_competitive_custom_tok.sh b/run_competitive_custom_tok.sh new file mode 100755 index 0000000000..1e8b6a4dfa --- /dev/null +++ b/run_competitive_custom_tok.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# === Parameter Golf: Competitive Entry with Custom Tokenizer === +# Same SOTA stack as run_competitive.sh but uses a custom-trained tokenizer. +# +# WORKFLOW: +# 1. Train tokenizer: python3 ../trainer-tokenizer/train_tokenizer.py --vocab-size 1024 --model-type unigram +# 2. Export shards: python3 ../trainer-tokenizer/train_tokenizer.py --vocab-size 1024 --model-type unigram --export-shards +# 3. Run this script with the paths +# +# Usage: +# CUSTOM_TOKENIZER=./tokenizers/spm_unigram_1024.model \ +# CUSTOM_DATA=./tokenizers/shards_unigram_1024 \ +# CUSTOM_VOCAB=1024 \ +# bash run_competitive_custom_tok.sh [--pilot] + +set -e + +PILOT=0 +if [[ "$1" == "--pilot" ]]; then + PILOT=1 +fi + +# Custom tokenizer paths (must be set) +CUSTOM_TOKENIZER="${CUSTOM_TOKENIZER:?Set CUSTOM_TOKENIZER to your .model file}" +CUSTOM_DATA="${CUSTOM_DATA:?Set CUSTOM_DATA to your shard directory}" +CUSTOM_VOCAB="${CUSTOM_VOCAB:-1024}" + +cd /workspace/parameter-golf + +pip install zstandard 2>/dev/null || true + +echo "=== Custom Tokenizer Run ===" +echo "Tokenizer: $CUSTOM_TOKENIZER" +echo "Data: $CUSTOM_DATA" +echo "Vocab: $CUSTOM_VOCAB" +echo "Train shards: $(ls ${CUSTOM_DATA}/fineweb_train_*.bin 2>/dev/null | wc -l)" +echo "Val shards: $(ls ${CUSTOM_DATA}/fineweb_val_*.bin 2>/dev/null | wc -l)" + +if [ "$PILOT" -eq 1 ]; then + NPROC=1 + echo "Mode: PILOT (1x GPU)" +else + NPROC=$(nvidia-smi --list-gpus | wc -l) + echo "Mode: COMPETITIVE ($NPROC GPUs)" +fi + +echo 
"Start: $(date)" + +RUN_ID="custom_tok_$(date +%Y%m%d_%H%M%S)" \ +DATA_PATH="$CUSTOM_DATA" \ +TOKENIZER_PATH="$CUSTOM_TOKENIZER" \ +VOCAB_SIZE="$CUSTOM_VOCAB" \ +SEED=42 \ +NUM_LAYERS=10 \ +MODEL_DIM=512 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_MULT=3.0 \ +TIE_EMBEDDINGS=1 \ +TRAIN_SEQ_LEN=2048 \ +TRAIN_BATCH_TOKENS=786432 \ +WARMDOWN_ITERS=3000 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +WEIGHT_DECAY=0.04 \ +MATRIX_LR=0.02 \ +SCALAR_LR=0.02 \ +TIED_EMBED_LR=0.03 \ +MUON_MOMENTUM=0.99 \ +MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 \ +GRAD_CLIP_NORM=0.3 \ +BIGRAM_VOCAB_SIZE=10240 \ +BIGRAM_DIM=128 \ +SWA_ENABLED=1 \ +SWA_START_FRAC=0.4 \ +SWA_EVERY=50 \ +EVAL_STRIDE=64 \ +EVAL_BATCH_SEQS=32 \ +torchrun --standalone --nproc_per_node=$NPROC train_gpt.py 2>&1 | tee /workspace/custom_tok_log.txt + +echo "" +echo "=== RESULTS ===" +grep -E 'val_bpb|final_int8|submission|model_params|swa:' /workspace/custom_tok_log.txt | tail -20 +echo "" +echo "Done: $(date)" diff --git a/run_custom_tokenizer_pipeline.sh b/run_custom_tokenizer_pipeline.sh new file mode 100755 index 0000000000..f3da04c45b --- /dev/null +++ b/run_custom_tokenizer_pipeline.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# ============================================================================= +# Parameter Golf: Custom Tokenizer + Competitive Run (All-in-One) +# ============================================================================= +# +# Steps: +# 1. Download docs_selected.jsonl (~45GB, 10-30 min) +# 2. Train unigram tokenizer (5-10 min) +# 3. Export binary shards (30-60 min) +# 4. 
Run SOTA training with custom tokenizer (10 min)
+#
+# Usage (paste into RunPod terminal):
+#
+#   git clone https://github.com/User123331/parameter-golf.git
+#   cd parameter-golf
+#   git pull
+#   bash run_custom_tokenizer_pipeline.sh
+#
+# =============================================================================
+
+set -e
+
+VOCAB_SIZE=1024
+MODEL_TYPE=unigram
+MAX_TRAIN_DOCS=200000
+EVAL_DOCS=10000
+HF_TOKEN="${HF_TOKEN:-}"  # never commit credentials; export HF_TOKEN in the environment
+
+DATA_DIR="./data/datasets"
+TOKENIZER_DIR="./data/tokenizers_custom"
+DOCS_JSONL="${DATA_DIR}/docs_selected.jsonl"
+CUSTOM_SHARDS="${DATA_DIR}/fineweb10B_custom_${MODEL_TYPE}${VOCAB_SIZE}"
+CUSTOM_MODEL="${TOKENIZER_DIR}/spm_${MODEL_TYPE}_${VOCAB_SIZE}.model"
+
+GREEN='\033[0;32m'
+NC='\033[0m'
+log() { echo -e "${GREEN}[$(date +%H:%M:%S)]${NC} $*"; }
+
+# =============================================================================
+# Step 1: Download docs_selected.jsonl
+# =============================================================================
+log "Step 1: Downloading docs_selected.jsonl (~45GB)..."
+
+mkdir -p "${DATA_DIR}"
+
+if [ ! -f "${DOCS_JSONL}" ]; then
+    pip install --quiet huggingface_hub
+
+    python3 -c "
+from huggingface_hub import hf_hub_download
+import shutil, os
+cached = hf_hub_download(
+    repo_id='willdepueoai/parameter-golf',
+    filename='docs_selected.jsonl',
+    subfolder='datasets',
+    repo_type='dataset',
+)
+src = os.path.realpath(cached)
+dst = '${DOCS_JSONL}'
+print(f'Copying to {dst}')
+try:
+    os.link(src, dst)
+except OSError:
+    shutil.copy2(src, dst)
+"
+fi
+
+log "Docs ready: $(du -h "${DOCS_JSONL}" 2>/dev/null | cut -f1)"
+
+# =============================================================================
+# Step 2: Train custom tokenizer
+# =============================================================================
+log "Step 2: Training ${MODEL_TYPE} tokenizer..."
+ +mkdir -p "${TOKENIZER_DIR}" +pip install --quiet sentencepiece numpy + +python3 data/train_tokenizer.py \ + --vocab-size ${VOCAB_SIZE} \ + --model-type ${MODEL_TYPE} \ + --docs-path "${DOCS_JSONL}" \ + --max-docs ${MAX_TRAIN_DOCS} \ + --eval-docs ${EVAL_DOCS} \ + --character-coverage 0.995 + +log "Tokenizer ready: ${CUSTOM_MODEL}" + +# ============================================================================= +# Step 3: Export binary shards +# ============================================================================= +log "Step 3: Exporting binary shards (30-60 min)..." + +python3 data/train_tokenizer.py \ + --vocab-size ${VOCAB_SIZE} \ + --model-type ${MODEL_TYPE} \ + --docs-path "${DOCS_JSONL}" \ + --export-shards \ + --shard-output-dir "${CUSTOM_SHARDS}" + +log "Shards ready: ${CUSTOM_SHARDS}" +log "Train shards: $(ls ${CUSTOM_SHARDS}/fineweb_train_*.bin 2>/dev/null | wc -l | tr -d ' ')" +log "Val shards: $(ls ${CUSTOM_SHARDS}/fineweb_val_*.bin 2>/dev/null | wc -l | tr -d ' ')" + +# ============================================================================= +# Step 4: Run training with custom tokenizer +# ============================================================================= +log "Step 4: Running SOTA training..." 
+
+pip install --quiet zstandard
+
+NPROC=$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ')
+if [ -z "$NPROC" ] || [ "$NPROC" -lt 1 ]; then NPROC=1; fi  # if/fi form: a bare `[..] || [..] && ..` list trips `set -e` when NPROC>=1
+
+RUN_ID="custom_${MODEL_TYPE}${VOCAB_SIZE}_$(date +%Y%m%d_%H%M%S)" \
+DATA_PATH="${CUSTOM_SHARDS}" \
+TOKENIZER_PATH="${CUSTOM_MODEL}" \
+VOCAB_SIZE=${VOCAB_SIZE} \
+SEED=42 \
+NUM_LAYERS=10 \
+MODEL_DIM=512 \
+NUM_HEADS=8 \
+NUM_KV_HEADS=4 \
+MLP_MULT=3.0 \
+TIE_EMBEDDINGS=1 \
+TRAIN_SEQ_LEN=2048 \
+TRAIN_BATCH_TOKENS=786432 \
+WARMDOWN_ITERS=3000 \
+MAX_WALLCLOCK_SECONDS=600 \
+VAL_LOSS_EVERY=500 \
+TRAIN_LOG_EVERY=100 \
+WEIGHT_DECAY=0.04 \
+MATRIX_LR=0.02 \
+SCALAR_LR=0.02 \
+TIED_EMBED_LR=0.03 \
+MUON_MOMENTUM=0.99 \
+MUON_MOMENTUM_WARMUP_START=0.92 \
+MUON_MOMENTUM_WARMUP_STEPS=1500 \
+GRAD_CLIP_NORM=0.3 \
+BIGRAM_VOCAB_SIZE=10240 \
+BIGRAM_DIM=128 \
+SWA_ENABLED=1 \
+SWA_START_FRAC=0.4 \
+SWA_EVERY=50 \
+EVAL_STRIDE=64 \
+EVAL_BATCH_SEQS=32 \
+torchrun --standalone --nproc_per_node=$NPROC train_gpt.py 2>&1 | tee /workspace/custom_tok_train.log
+
+log "Done!"
+grep -E 'val_bpb|final_int8' /workspace/custom_tok_train.log | tail -5
\ No newline at end of file
diff --git a/train_gpt.py b/train_gpt.py
index f74f343143..bbe5ab2943 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -19,6 +19,12 @@
 import zlib
 from pathlib import Path
 
+try:
+    import zstandard
+    _COMPRESSOR = "zstd"
+except ImportError:
+    _COMPRESSOR = "zlib"
+
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -30,86 +36,68 @@
 # -----------------------------
 # HYPERPARAMETERS
 # -----------------------------
-# Default Simple Baseline run:
-# - 9 transformer blocks at width 512
-# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
-# - vocab size 1024, sequence length 1024, tied embeddings
-# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
 class Hyperparameters:
-    # Data paths are shard globs produced by the existing preprocessing pipeline.
data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") train_files = os.path.join(data_path, "fineweb_train_*.bin") val_files = os.path.join(data_path, "fineweb_val_*.bin") tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) + seed = int(os.environ.get("SEED", 42)) - # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - # Model shape. 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # Test-time training (LoRA) hyperparameters. 
- ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) - ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) - ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) - ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) - - # Mixture of Softmax (MoS) output layer - breaks softmax bottleneck. - # At vocab=1024, dim=512, standard softmax has rank ≤ 513 (binding constraint). - # MoS with K=2 lifts this to rank ≤ 1026, enabling richer output distributions. - use_mos = bool(int(os.environ.get("USE_MOS", "0"))) - mos_k = int(os.environ.get("MOS_K", 2)) - mos_rank = int(os.environ.get("MOS_RANK", 64)) # 0 = full-rank, >0 = low-rank factorization + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # ----------------------------- # MUON OPTIMIZER # ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. 
a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= X.norm() + eps @@ -124,10 +112,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) - class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): super().__init__( params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), ) @torch.no_grad() @@ -136,7 +124,6 @@ def step(self, closure=None): if closure is not None: with torch.enable_grad(): loss = closure() - distributed = dist.is_available() and dist.is_initialized() world_size = dist.get_world_size() if distributed else 1 rank = dist.get_rank() if distributed else 0 @@ -165,7 +152,6 @@ def step(self, closure=None): if nesterov: g = g.add(buf, alpha=momentum) g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. g *= max(1, g.size(0) / g.size(1)) ** 0.5 updates_flat[curr : curr + p.numel()] = g.reshape(-1) curr += p.numel() @@ -173,23 +159,20 @@ def step(self, closure=None): if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) curr = 0 for p in params: g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) p.add_(g, alpha=-lr) curr += p.numel() - return loss # ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP +# TOKENIZER-AGNOSTIC EVALUATION # ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. 
-# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. def build_sentencepiece_luts( sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device @@ -207,7 +190,7 @@ def build_sentencepiece_luts( base_bytes_np[token_id] = 1 continue piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): + if piece.startswith("\u2581"): has_leading_space_np[token_id] = True piece = piece[1:] base_bytes_np[token_id] = len(piece.encode("utf-8")) @@ -222,7 +205,6 @@ def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: @@ -242,9 +224,6 @@ def eval_val( has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, ) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) if local_batch_tokens < args.train_seq_len: raise ValueError( @@ -259,7 +238,6 @@ def eval_val( val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): @@ -279,34 +257,34 @@ def eval_val( token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count bits_per_token = val_loss.item() / math.log(2.0) tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + # ----------------------------- -# POST-TRAINING QUANTIZATION +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) # ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. 
-# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", ).split(",") if pattern ) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( @@ -324,19 +302,9 @@ def eval_val( def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. clip_abs = ( torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() @@ -346,105 +314,95 @@ def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. 
clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
- if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if 
qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t + out[name] = (q.float() * float(s.item())).to(orig_dtype) return out # ----------------------------- -# DATA LOADING +# DATA LOADING # ----------------------------- def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" Tensor: class TokenStream: - # Reads shards sequentially and wraps around forever. The training loop therefore - # has deterministic, simple streaming behavior with no sampling or workers. def __init__(self, pattern: str): self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: @@ -489,8 +445,6 @@ def take(self, n: int) -> Tensor: class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size @@ -507,6 +461,7 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + # ----------------------------- # TRANSFORMER MODULES # ----------------------------- @@ -521,14 +476,13 @@ def forward(self, x: Tensor) -> Tensor: class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) + return F.linear(x, w, bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: @@ -536,7 +490,6 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None: class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. 
def __init__(self, dim: int, base: float = 10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) @@ -567,14 +520,7 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): super().__init__() if dim % num_heads != 0: raise ValueError("model_dim must be divisible by num_heads") @@ -594,14 +540,11 @@ def __init__( self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, base=rope_base) - def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape - q = self.c_q(x) + (q_delta if q_delta is not None else 0) - k = self.c_k(x) - v = self.c_v(x) + (v_delta if v_delta is not None else 0) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) @@ -609,11 +552,7 @@ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: k = apply_rotary_emb(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, + q, k, v, attn_mask=None, is_causal=True, 
enable_gqa=(self.num_kv_heads != self.num_heads), ) y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) @@ -621,10 +560,9 @@ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): + def __init__(self, dim: int, mlp_mult: float): super().__init__() - hidden = mlp_mult * dim + hidden = int(mlp_mult * dim) self.fc = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) self.proj._zero_init = True @@ -634,16 +572,47 @@ def forward(self, x: Tensor) -> Tensor: return self.proj(x.square()) +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> 
Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() @@ -653,85 +622,15 @@ def __init__( self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + def forward(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) + attn_out = self.attn(self.attn_norm(x)) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x -class MixtureOfSoftmax(nn.Module): - """Mixture of Softmax output layer for breaking the softmax bottleneck. - - At vocab=1024, dim=512, the standard softmax has rank ≤ 513. - MoS with K=2 lifts this to rank ≤ 1026, enabling richer output distributions. - - When mos_rank > 0, uses low-rank factorization to save parameters: - instead of dim -> K*dim projection, uses dim -> rank -> K*dim. - - Paper: Yang et al. (2018), "Breaking the Softmax Bottleneck", ICLR 2018. 
- """ - - def __init__(self, model_dim: int, vocab_size: int, n_mixtures: int = 2, rank: int = 0): - super().__init__() - self.n_mixtures = n_mixtures - self.model_dim = model_dim - self.vocab_size = vocab_size - self.rank = rank - - if rank > 0: - # Low-rank factorization: dim -> rank -> K*dim - self.proj_down = CastedLinear(model_dim, rank, bias=False) - self.proj_up = CastedLinear(rank, n_mixtures * model_dim, bias=False) - nn.init.normal_(self.proj_down.weight, mean=0.0, std=0.02) - nn.init.normal_(self.proj_up.weight, mean=0.0, std=0.02) - else: - # Full-rank: dim -> K*dim - self.projections = CastedLinear(model_dim, n_mixtures * model_dim, bias=False) - nn.init.normal_(self.projections.weight, mean=0.0, std=0.02) - - # Mixing weight predictor - self.gate = CastedLinear(model_dim, n_mixtures, bias=False) - nn.init.normal_(self.gate.weight, mean=0.0, std=0.02) - - def forward(self, hidden: Tensor, weight_matrix: Tensor) -> Tensor: - """Compute mixed softmax distribution. - - Args: - hidden: (bsz, seq_len, dim) - final hidden states - weight_matrix: (vocab_size, dim) - tied embedding weights - - Returns: - log_probs: (bsz, seq_len, vocab_size) - mixed log probabilities - """ - bsz, seq_len, dim = hidden.shape - K = self.n_mixtures - - # Compute mixing weights: (bsz, seq, K) - pi = F.softmax(self.gate(hidden), dim=-1) - - # Project to K different spaces: (bsz, seq, K * dim) -> (bsz, seq, K, dim) - if self.rank > 0: - projected = self.proj_up(self.proj_down(hidden)).view(bsz, seq_len, K, dim) - else: - projected = self.projections(hidden).view(bsz, seq_len, K, dim) - - # Compute K different logit vectors: (bsz, seq, K, vocab) - logits = torch.einsum('bskd,vd->bskv', projected, weight_matrix) - - # Mix softmax distributions using log-space for numerical stability - log_probs = F.log_softmax(logits, dim=-1) # (bsz, seq, K, vocab) - log_pi = torch.log(pi.unsqueeze(-1) + 1e-10) # (bsz, seq, K, 1) - mixed_log_probs = torch.logsumexp(log_probs + log_pi, dim=2) # (bsz, 
seq, vocab) - - return mixed_log_probs - - class GPT(nn.Module): def __init__( self, @@ -740,15 +639,14 @@ def __init__( model_dim: int, num_heads: int, num_kv_heads: int, - mlp_mult: int, + mlp_mult: float, tie_embeddings: bool, tied_embed_init_std: float, logit_softcap: float, rope_base: float, qk_gain_init: float, - use_mos: bool = False, - mos_k: int = 2, - mos_rank: int = 0, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, ): super().__init__() if logit_softcap <= 0.0: @@ -756,31 +654,20 @@ def __init__( self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap - self.use_mos = use_mos self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) self.blocks = nn.ModuleList( [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) ] ) self.final_norm = RMSNorm() - # MoS output layer (optional) - breaks softmax bottleneck - if use_mos: - self.mos = MixtureOfSoftmax(model_dim, vocab_size, n_mixtures=mos_k, rank=mos_rank) - else: - self.mos = None self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True @@ -789,264 +676,145 @@ def __init__( def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if 
isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) x0 = x skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") - # First half stores skips; second half reuses them in reverse order. 
+ def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] for i in range(self.num_encoder_layers): - qd = lora.q_loras[i] if lora else None - vd = lora.v_loras[i] if lora else None - x = self.blocks[i](x, x0, qd, vd) + x = self.blocks[i](x, x0) skips.append(x) for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - qd = lora.q_loras[bi] if lora else None - vd = lora.v_loras[bi] if lora else None - x = self.blocks[bi](x, x0, qd, vd) + x = self.blocks[self.num_encoder_layers + i](x, x0) x = self.final_norm(x) - # Output layer - if self.mos is not None and self.tie_embeddings: - # MoS: returns log-probs (already log-softmaxed), use NLL loss directly - log_probs = self.mos(x, self.tok_emb.weight) - if lora: - # LoRA correction breaks normalization; re-normalize via log_softmax - log_probs = F.log_softmax(log_probs + lora.lm_head_lora(x), dim=-1) - bsz, sl, V = log_probs.shape - return F.nll_loss( - log_probs.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.nll_loss(log_probs.float().reshape(-1, log_probs.size(-1)), target_ids.reshape(-1), reduction="mean") - elif self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) else: - logits = self.lm_head(x) - logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - if lora: - bsz, sl, V = logits.shape - return F.cross_entropy( - logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + logits_proj = 
self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. - -BOS_ID = 1 - -class BatchedLinearLoRA(nn.Module): - """LoRA for a linear layer, with independent weights per batch element. - Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" - def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): - super().__init__() - self.in_features = in_features - self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection - self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection - self.reset() - - def forward(self, x: Tensor) -> Tensor: - return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) - - def reset(self) -> None: - bound = 1.0 / math.sqrt(self.in_features) - with torch.no_grad(): - self.A.uniform_(-bound, bound) # kaiming-uniform - self.B.zero_() - -class BatchedTTTLoRA(nn.Module): - """All LoRA adapters for one batch: LM head and Q/V per block.""" - def __init__(self, bsz: int, model: GPT, rank: int): - super().__init__() - dim = model.tok_emb.embedding_dim - vocab = model.tok_emb.num_embeddings - self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) - self.q_loras = nn.ModuleList() - self.v_loras = nn.ModuleList() - for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) - - def reset(self) -> None: - for m in self.modules(): - if isinstance(m, BatchedLinearLoRA): - m.reset() - -def _reset_ttt_optimizer(opt): - for group in opt.param_groups: - for p in group['params']: - s = opt.state.get(p) - if not s: # Fresh state. 
- continue - s['exp_avg'].zero_() - s['exp_avg_sq'].zero_() - s['step'].fill_(0) - -def _build_ttt_optimizer(lora, args: Hyperparameters): - return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) - -def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: - """Return (start_offset, length) for each document, identified by BOS boundaries. - - If include_next_bos is True, include next document's BOS (to match continuous-stream - eval token count exactly). - """ - bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() - docs = [] - for i in range(len(bos_positions)): - start = int(bos_positions[i]) - end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() - if include_next_bos and i + 1 < len(bos_positions): - end += 1 - assert end - start >= 2 - docs.append((start, end - start)) - return docs - -def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): - """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" - chunk_start = ci * chunk_size - chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size - win_start = max(0, chunk_end - eval_seq_len) - win_len = chunk_end - win_start - chunk_offset = chunk_start - win_start - chunk_len = chunk_end - chunk_start - return win_start, win_len, chunk_offset, chunk_len - -def _accumulate_bpb( - ptl: Tensor, x: Tensor, y: Tensor, - batch_i: int, chunk_offset: int, chunk_len: int, - base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, -): - """Add one doc-chunk's contribution to the running BPB accumulators.""" - lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) - prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] - tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] - tok_bytes = 
base_bytes_lut[tgt].to(torch.float64) - tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] - loss_sum += lbl.sum() - byte_sum += tok_bytes.sum() - token_count += chunk_len - -def eval_val_ttt_lora( +def eval_val_sliding( args: Hyperparameters, - base_model: GPT, + base_model: nn.Module, rank: int, world_size: int, device: torch.device, + val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, ) -> tuple[float, float]: - """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries - files = sorted(glob.glob(args.val_files)) - all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) - docs = _find_docs(all_tokens) - - # Each rank takes a contiguous slice of documents - rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] - chunk_size = args.ttt_chunk_size - eval_seq_len = args.ttt_eval_seq_len - batch_size = args.ttt_batch_size - lora_rank = args.ttt_lora_rank - - rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) - - base_model.eval() - for p in base_model.parameters(): - p.requires_grad_(False) - - lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) - opt = _build_ttt_optimizer(lora, args) + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] loss_sum = torch.zeros((), device=device, dtype=torch.float64) - byte_sum = torch.zeros((), device=device, dtype=torch.float64) token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) - 
for bi in range(0, len(rank_docs), batch_size): - batch = rank_docs[bi:bi + batch_size] - bsz = len(batch) - - if bsz == batch_size: - cur_lora, cur_opt = lora, opt - cur_lora.reset() - _reset_ttt_optimizer(cur_opt) - else: - cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) - cur_opt = _build_ttt_optimizer(cur_lora, args) - - pred_lens = [doc_len - 1 for _, doc_len in batch] - num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] - max_nc = max(num_chunks) - - for ci in range(max_nc): - chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) - context_size, chunk_offset = chunk_stats[1], chunk_stats[2] - - active = [ci < nc for nc in num_chunks] - needs_train = any(ci < nc - 1 for nc in num_chunks) - - x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - doc_info = [] # (chunk_offset, chunk_len) per doc - for b in range(bsz): - if not active[b]: - doc_info.append((0, 0)) - continue - ds, dl = batch[b] - ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) - chunk = all_tokens[ds + ws: ds + ws + wl + 1] - toks = chunk.to(dtype=torch.int64, device=device) - x[b, :wl] = toks[:-1] - y[b, :wl] = toks[1:] - doc_info.append((co, cl)) - - # Forward pass (keep grad graph alive only when we need to train) - if needs_train: - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - else: - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - - # Score: accumulate loss and byte counts for BPB (before training on chunk) - with torch.no_grad(): - for b in range(bsz): - if not active[b]: - continue - co, cl = doc_info[b] - _accumulate_bpb( - ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, loss_sum, byte_sum, token_count) - 
- # Train: one Adam step on the LoRA params using this chunk's loss - if needs_train: - mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) - per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) - cur_opt.zero_grad() - (per_doc * mask).sum().backward() - cur_opt.step() + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) 
if dist.is_available() and dist.is_initialized(): dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte - val_loss = float(loss_sum.item() / token_count.item()) - val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) - return val_loss, val_bpb # ----------------------------- # TRAINING @@ -1059,10 +827,6 @@ def main() -> None: args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -1082,11 +846,9 @@ def main() -> None: dist.barrier() master_process = rank == 0 - # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) enable_flash_sdp(True) enable_mem_efficient_sdp(False) @@ -1117,10 +879,6 @@ def log0(msg: str, console: bool = True) -> None: ) log0("=" * 100, console=False) - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -1143,10 +901,7 @@ def log0(msg: str, console: bool = True) -> None: log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - # 
----------------------------- # MODEL + OPTIMIZER SETUP - # ----------------------------- - base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, @@ -1159,50 +914,43 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - use_mos=args.use_mos, - mos_k=args.mos_k, - mos_rank=args.mos_rank, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() - if isinstance(module, Rotary): - module.inv_freq.data = module.inv_freq.data.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ - p - for name, p in block_named_params + p for name, p in block_named_params if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] scalar_params = [ - p - for name, p in block_named_params + p for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) - # MoS parameters: 2D projection weights go to Muon, gate goes to scalar optimizer - if base_model.mos is not None: - if base_model.mos.rank > 0: - matrix_params.append(base_model.mos.proj_down.weight) - matrix_params.append(base_model.mos.proj_up.weight) - else: - matrix_params.append(base_model.mos.projections.weight) - 
scalar_params.append(base_model.mos.gate.weight) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.weight_decay, fused=True, ) optimizer_muon = Muon( @@ -1210,13 +958,15 @@ def log0(msg: str, console: bool = True) -> None: lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, + weight_decay=0.04, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( + optimizer_scalar = torch.optim.AdamW( [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.weight_decay, fused=True, ) optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] @@ -1232,11 +982,9 @@ def log0(msg: str, console: bool = True) -> None: n_params = sum(p.numel() for p in base_model.parameters()) log0(f"model_params:{n_params}") log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") log0( f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is 
not None else 0.0} " f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" ) log0( @@ -1246,10 +994,7 @@ def log0(msg: str, console: bool = True) -> None: ) log0(f"seed:{args.seed}") - # ----------------------------- # DATA LOADER & MODEL WARMUP - # ----------------------------- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) def zero_grad_all() -> None: @@ -1269,8 +1014,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. if args.warmup_steps > 0: initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] @@ -1297,12 +1040,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - # ----------------------------- # MAIN TRAINING LOOP - # ----------------------------- - training_time_ms = 0.0 stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 torch.cuda.synchronize() t0 = time.perf_counter() @@ -1315,16 +1057,8 @@ def lr_mul(step: int, elapsed_ms: float) -> float: torch.cuda.synchronize() training_time_ms += 1000.0 * (time.perf_counter() - t0) val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, ) log0( f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} 
val_bpb:{val_bpb:.4f} " @@ -1372,6 +1106,18 @@ def lr_mul(step: int, elapsed_ms: float) -> float: step += 1 approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) @@ -1382,7 +1128,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) - # Needed to sync whether we've reached the wallclock cap. reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms if distributed and max_wallclock_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) @@ -1396,12 +1141,17 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + # SERIALIZATION + ROUNDTRIP VALIDATION if master_process: torch.save(base_model.state_dict(), "final_model.pt") model_bytes = os.path.getsize("final_model.pt") @@ -1410,44 +1160,60 @@ def lr_mul(step: int, elapsed_ms: float) -> float: log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) if master_process: with open("final_model.int8.ptz", "wb") as f: f.write(quant_blob) quant_file_bytes = os.path.getsize("final_model.int8.ptz") code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - 
f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") if distributed: dist.barrier() with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights torch.cuda.synchronize() t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) torch.cuda.synchronize() log0( f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " @@ 
-1455,23 +1221,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - # LoRA test-time training evaluation (the competition score) - torch._dynamo.reset() - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( - args, base_model, rank, world_size, device, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" - ) - if distributed: dist.destroy_process_group() if __name__ == "__main__": main() +# fixes applied +# tuned From 34ff363168968dcb3f872cfe7154cb5e20505eb5 Mon Sep 17 00:00:00 2001 From: User123331 Date: Sat, 21 Mar 2026 12:31:45 +0700 Subject: [PATCH 12/32] Add train_tokenizer_only.sh for focused tokenizer training Co-Authored-By: Claude Opus 4.6 --- train_tokenizer_only.sh | 122 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100755 train_tokenizer_only.sh diff --git a/train_tokenizer_only.sh b/train_tokenizer_only.sh new file mode 100755 index 0000000000..26971f6eee --- /dev/null +++ b/train_tokenizer_only.sh @@ -0,0 +1,122 @@ +#!/bin/bash +# ============================================================================= +# Parameter Golf: Train Custom Tokenizer Only +# ============================================================================= +# +# This script: +# 1. Downloads docs_selected.jsonl (~45GB) +# 2. Trains unigram tokenizer +# 3. 
Evaluates against baseline +# +# Output: data/tokenizers_custom/spm_unigram_1024.model (~1MB) +# +# Usage on RunPod: +# git clone https://github.com/User123331/parameter-golf.git +# cd parameter-golf +# bash train_tokenizer_only.sh +# ============================================================================= + +set -e + +VOCAB_SIZE=1024 +MODEL_TYPE=unigram +MAX_TRAIN_DOCS=200000 +EVAL_DOCS=10000 + +DATA_DIR="./data/datasets" +TOKENIZER_DIR="./data/tokenizers_custom" +DOCS_JSONL="${DATA_DIR}/docs_selected.jsonl" + +GREEN='\033[0;32m' +CYAN='\033[0;36m' +NC='\033[0m' +log() { echo -e "${GREEN}[$(date +%H:%M:%S)]${NC} $*"; } + +echo "" +echo "============================================================" +echo " Custom Tokenizer Training" +echo " Vocab: ${VOCAB_SIZE}, Type: ${MODEL_TYPE}" +echo "============================================================" +echo "" + +# Check disk +AVAIL_GB=$(df -BG . 2>/dev/null | tail -1 | awk '{print $4}' | tr -d 'G' || echo "?") +log "Available disk: ${AVAIL_GB} GB (need ~50GB)" + +# ============================================================================= +# Step 1: Download docs_selected.jsonl +# ============================================================================= +if [ -f "${DOCS_JSONL}" ]; then + log "Docs already exist: $(du -h "${DOCS_JSONL}" | cut -f1)" +else + log "Downloading docs_selected.jsonl (~45GB, 10-30 min)..." 
+ + mkdir -p "${DATA_DIR}" + pip install --quiet huggingface_hub + + python3 -c " +from huggingface_hub import hf_hub_download +import shutil, os + +cached = hf_hub_download( + repo_id='willdepueoai/parameter-golf', + filename='docs_selected.jsonl', + subfolder='datasets', + repo_type='dataset', +) +src = os.path.realpath(cached) +dst = '${DOCS_JSONL}' +print(f'Copying {src} -> {dst}') +try: + os.link(src, dst) + print('Hard-linked (no extra disk)') +except OSError: + shutil.copy2(src, dst) + print('Copied') +" + log "Download complete: $(du -h "${DOCS_JSONL}" | cut -f1)" +fi + +# ============================================================================= +# Step 2: Train unigram tokenizer +# ============================================================================= +log "Training ${MODEL_TYPE} tokenizer (vocab=${VOCAB_SIZE})..." + +mkdir -p "${TOKENIZER_DIR}" +pip install --quiet sentencepiece numpy + +python3 data/train_tokenizer.py \ + --vocab-size ${VOCAB_SIZE} \ + --model-type ${MODEL_TYPE} \ + --docs-path "${DOCS_JSONL}" \ + --max-docs ${MAX_TRAIN_DOCS} \ + --eval-docs ${EVAL_DOCS} \ + --character-coverage 0.995 + +# ============================================================================= +# Step 3: Evaluate baseline for comparison +# ============================================================================= +log "" +log "Evaluating baseline tokenizer for comparison..." 
+ +BASELINE_MODEL="./data/tokenizers/fineweb_1024_bpe.model" +if [ -f "${BASELINE_MODEL}" ]; then + python3 data/train_tokenizer.py \ + --evaluate "${BASELINE_MODEL}" \ + --docs-path "${DOCS_JSONL}" \ + --eval-docs ${EVAL_DOCS} +else + log "Baseline not found at ${BASELINE_MODEL}" + log "Download it with: python data/cached_challenge_fineweb.py --variant sp1024" +fi + +# ============================================================================= +# Summary +# ============================================================================= +echo "" +echo "============================================================" +log "TOKENIZER TRAINED: ${TOKENIZER_DIR}/spm_${MODEL_TYPE}_${VOCAB_SIZE}.model" +log "" +log "To download from RunPod:" +log " scp root@:/workspace/parameter-golf/${TOKENIZER_DIR}/spm_${MODEL_TYPE}_${VOCAB_SIZE}.model ." +echo "============================================================" \ No newline at end of file From db62d709347b4ebfcf4b0c84404c7a20e2f018f0 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sun, 22 Mar 2026 10:08:47 +0700 Subject: [PATCH 13/32] Add RunPod SOTA launcher Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- setup_and_run_sota_1xh100.sh | 262 +++++++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100755 setup_and_run_sota_1xh100.sh diff --git a/setup_and_run_sota_1xh100.sh b/setup_and_run_sota_1xh100.sh new file mode 100755 index 0000000000..3fe4b71b25 --- /dev/null +++ b/setup_and_run_sota_1xh100.sh @@ -0,0 +1,262 @@ +#!/usr/bin/env bash +# === Parameter Golf: SOTA comparison on 1x H100 (RunPod) === +# Default target: PR #198 (1.1318 bpb on 8xH100, current best open result in local notes) +# +# Usage on RunPod: +# git clone https://github.com/User123331/runpod-testing.git +# cd runpod-testing +# HF_TOKEN=hf_xxx bash setup_and_run_sota_1xh100.sh +# +# Optional: +# TARGET_PR=180 bash setup_and_run_sota_1xh100.sh # thwu1 merged record +# SEED=42 bash 
setup_and_run_sota_1xh100.sh +# TRAIN_SHARDS=1 bash setup_and_run_sota_1xh100.sh # smoke download only + +set -euo pipefail + +log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } +warn() { printf '[%s] WARNING: %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*" >&2; } +die() { printf '[%s] ERROR: %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*" >&2; exit 1; } + +require_cmd() { + command -v "$1" >/dev/null 2>&1 || die "Required command not found: $1" +} + +require_clean_checkout() { + if ! git diff --quiet || ! git diff --cached --quiet; then + die "Existing checkout at ${SRC_DIR} has uncommitted changes. Use a fresh SRC_DIR." + fi +} + +discover_legacy_hf_token() { + python3 - "$@" <<'PY' +from pathlib import Path +import re +import sys + +pattern = re.compile(r'export\s+HF_TOKEN="\$\{HF_TOKEN:-([^"}]+)\}"') + +for raw_path in sys.argv[1:]: + path = Path(raw_path) + if not path.is_file(): + continue + text = path.read_text(encoding="utf-8", errors="ignore") + match = pattern.search(text) + if match: + print(f"{match.group(1)}\t{path}") + sys.exit(0) + +sys.exit(1) +PY +} + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TARGET_PR="${TARGET_PR:-198}" +WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}" +SRC_DIR="${SRC_DIR:-${WORKSPACE_DIR}/parameter-golf-pr${TARGET_PR}}" +LOG_DIR="${LOG_DIR:-${WORKSPACE_DIR}/logs}" +TRAIN_SHARDS="${TRAIN_SHARDS:-80}" +MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-786432}" +TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +HF_TOKEN="${HF_TOKEN:-${HUGGINGFACE_TOKEN:-}}" +HF_TOKEN_SOURCE="" + +case "${TARGET_PR}" in + 198) + TARGET_SHA="${TARGET_SHA:-372bddea57f465c7217c5e26af2252a803221518}" + TRAIN_SCRIPT_REL="records/track_10min_16mb/2026-03-20_11L_Int6_MLP3x_WD04_SmearBigram2k_1.1318/train_gpt.py" + NUM_LAYERS="${NUM_LAYERS:-11}" + BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-2048}" + MATRIX_LR="${MATRIX_LR:-0.025}" + SCALAR_LR="${SCALAR_LR:-0.025}" + TIED_EMBED_LR="${TIED_EMBED_LR:-0.035}" 
+ MUON_MOMENTUM="${MUON_MOMENTUM:-0.99}" + MUON_MOMENTUM_WARMUP_START="${MUON_MOMENTUM_WARMUP_START:-0.92}" + MUON_MOMENTUM_WARMUP_STEPS="${MUON_MOMENTUM_WARMUP_STEPS:-1500}" + WARMDOWN_ITERS="${WARMDOWN_ITERS:-3000}" + ITERATIONS="${ITERATIONS:-9000}" + EVAL_STRIDE="${EVAL_STRIDE:-64}" + VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-1000}" + TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-200}" + SWA_EVERY="${SWA_EVERY:-200}" + MUON_WD="${MUON_WD:-0.04}" + ADAM_WD="${ADAM_WD:-0.04}" + SEED="${SEED:-1337}" + REQUIRED_PY_MODULES="flash_attn_interface" + ;; + 180) + TARGET_SHA="${TARGET_SHA:-1a8be36c17e20b1fb53dbf4975e1d67f5b8a63e9}" + TRAIN_SCRIPT_REL="records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py" + NUM_LAYERS="${NUM_LAYERS:-10}" + BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-10240}" + MATRIX_LR="${MATRIX_LR:-0.02}" + SCALAR_LR="${SCALAR_LR:-0.02}" + TIED_EMBED_LR="${TIED_EMBED_LR:-0.03}" + MUON_MOMENTUM="${MUON_MOMENTUM:-0.99}" + MUON_MOMENTUM_WARMUP_START="${MUON_MOMENTUM_WARMUP_START:-0.92}" + MUON_MOMENTUM_WARMUP_STEPS="${MUON_MOMENTUM_WARMUP_STEPS:-1500}" + WARMDOWN_ITERS="${WARMDOWN_ITERS:-3000}" + ITERATIONS="${ITERATIONS:-9000}" + EVAL_STRIDE="${EVAL_STRIDE:-64}" + VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-500}" + TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-100}" + SWA_START_FRAC="${SWA_START_FRAC:-0.4}" + SWA_EVERY="${SWA_EVERY:-50}" + WEIGHT_DECAY="${WEIGHT_DECAY:-0.04}" + SEED="${SEED:-42}" + REQUIRED_PY_MODULES="" + ;; + *) + die "Unsupported TARGET_PR=${TARGET_PR}. Use 198 (default) or 180." + ;; +esac + +RUN_ID="${RUN_ID:-pr${TARGET_PR}_1xh100_seed${SEED}_$(date +%Y%m%d_%H%M%S)}" +LOG_PATH="${LOG_DIR}/${RUN_ID}.log" + +require_cmd git +require_cmd python3 +require_cmd torchrun +require_cmd nvidia-smi + +GPU_COUNT="$(nvidia-smi --list-gpus | wc -l | tr -d ' ')" +[ "${GPU_COUNT}" -ge 1 ] || die "No GPUs detected." + +log "Detected ${GPU_COUNT} GPU(s). This script uses exactly 1 GPU." 
+if [ -z "${HF_TOKEN}" ]; then + if TOKEN_RECORD="$( + discover_legacy_hf_token \ + "${SCRIPT_DIR}/setup_and_run.sh" \ + "${SCRIPT_DIR}/setup_and_run_1h.sh" \ + "${SCRIPT_DIR}/run_custom_tokenizer_pipeline.sh" \ + "${SCRIPT_DIR}/../parameter-golf/setup_and_run.sh" \ + "${SCRIPT_DIR}/../parameter-golf/setup_and_run_1h.sh" \ + "${SCRIPT_DIR}/../trainer-tokenizer/setup_runpod.sh" + )"; then + IFS=$'\t' read -r HF_TOKEN HF_TOKEN_SOURCE <<< "${TOKEN_RECORD}" + fi +fi + +if [ -n "${HF_TOKEN}" ]; then + export HF_TOKEN + if [ -n "${HF_TOKEN_SOURCE}" ]; then + warn "Using HF token found in existing local script: ${HF_TOKEN_SOURCE}" + else + log "HF token detected in environment; authenticated downloads enabled." + fi +else + warn "HF_TOKEN/HUGGINGFACE_TOKEN not set. Public downloads may still work, but auth is recommended." +fi + +python3 - <<'PY' +import importlib.util +import subprocess +import sys + +missing = [pkg for pkg in ("huggingface_hub", "zstandard") if importlib.util.find_spec(pkg) is None] +if missing: + subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", *missing]) +PY + +mkdir -p "${LOG_DIR}" + +if [ ! 
-d "${SRC_DIR}/.git" ]; then + log "Cloning openai/parameter-golf into ${SRC_DIR}" + git clone https://github.com/openai/parameter-golf.git "${SRC_DIR}" +fi + +cd "${SRC_DIR}" +require_clean_checkout + +log "Fetching PR #${TARGET_PR}" +git fetch origin "pull/${TARGET_PR}/head:runpod-pr-${TARGET_PR}" --force +git checkout --detach "${TARGET_SHA}" + +CURRENT_SHA="$(git rev-parse HEAD)" +[ "${CURRENT_SHA}" = "${TARGET_SHA}" ] || die "Checked out ${CURRENT_SHA}, expected ${TARGET_SHA}" +log "Checked out PR #${TARGET_PR} commit ${CURRENT_SHA}" +[ -f "${TRAIN_SCRIPT_REL}" ] || die "Target training script not found: ${TRAIN_SCRIPT_REL}" + +python3 - "${TRAIN_SCRIPT_REL}" "${REQUIRED_PY_MODULES}" <<'PY' +from pathlib import Path +import importlib.util +import sys + +train_script = Path(sys.argv[1]) +required_modules = [m for m in sys.argv[2].split(",") if m] + +source = train_script.read_text(encoding="utf-8") +compile(source, str(train_script), "exec") + +missing = [m for m in required_modules if importlib.util.find_spec(m) is None] +if missing: + raise SystemExit(f"Missing required Python modules for {train_script}: {', '.join(missing)}") +PY +log "Preflight compile check passed for ${TRAIN_SCRIPT_REL}" + +DATASET_DIR="data/datasets/fineweb10B_sp1024" +TOKENIZER_PATH="data/tokenizers/fineweb_1024_bpe.model" + +if [ ! -f "${DATASET_DIR}/fineweb_train_000000.bin" ] || [ ! -f "${TOKENIZER_PATH}" ]; then + log "Downloading FineWeb cached dataset/tokenizer (train_shards=${TRAIN_SHARDS})" + python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards "${TRAIN_SHARDS}" +else + log "Dataset/tokenizer already present; skipping download." 
+fi + +TRAIN_COUNT="$(find "${DATASET_DIR}" -maxdepth 1 -name 'fineweb_train_*.bin' 2>/dev/null | wc -l | tr -d ' ')" +VAL_COUNT="$(find "${DATASET_DIR}" -maxdepth 1 -name 'fineweb_val_*.bin' 2>/dev/null | wc -l | tr -d ' ')" +log "Dataset ready: train_shards=${TRAIN_COUNT} val_shards=${VAL_COUNT}" + +export PYTHONUNBUFFERED=1 +export RUN_ID +export DATA_PATH="./${DATASET_DIR}" +export TOKENIZER_PATH="./${TOKENIZER_PATH}" +export VOCAB_SIZE="${VOCAB_SIZE:-1024}" +export NUM_LAYERS +export MODEL_DIM="${MODEL_DIM:-512}" +export NUM_HEADS="${NUM_HEADS:-8}" +export NUM_KV_HEADS="${NUM_KV_HEADS:-4}" +export MLP_MULT="${MLP_MULT:-3.0}" +export TIE_EMBEDDINGS="${TIE_EMBEDDINGS:-1}" +export TRAIN_BATCH_TOKENS +export TRAIN_SEQ_LEN +export BIGRAM_VOCAB_SIZE +export BIGRAM_DIM="${BIGRAM_DIM:-128}" +export MATRIX_LR +export SCALAR_LR +export TIED_EMBED_LR +export MUON_MOMENTUM +export MUON_MOMENTUM_WARMUP_START +export MUON_MOMENTUM_WARMUP_STEPS +export SWA_ENABLED="${SWA_ENABLED:-1}" +export EVAL_STRIDE +export EVAL_BATCH_SEQS="${EVAL_BATCH_SEQS:-32}" +export ITERATIONS +export WARMDOWN_ITERS +export MAX_WALLCLOCK_SECONDS +export VAL_LOSS_EVERY +export TRAIN_LOG_EVERY +export SEED + +if [ "${TARGET_PR}" = "198" ]; then + export MUON_WD + export ADAM_WD + export SWA_EVERY +else + export WEIGHT_DECAY + export SWA_START_FRAC + export SWA_EVERY +fi + +log "Starting 1xH100 run for PR #${TARGET_PR}" +log "Run ID: ${RUN_ID}" +log "Log file: ${LOG_PATH}" + +torchrun --standalone --nproc_per_node=1 "${TRAIN_SCRIPT_REL}" 2>&1 | tee "${LOG_PATH}" + +log "Run completed. Final metrics:" +grep -E 'val_bpb|val_loss|artifact|bytes|final_int|submission|model_params|swa:' "${LOG_PATH}" | tail -20 || true + +log "Done." 
From 32c694d1456e533b67fcd5eebcc96425de05adc9 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sun, 22 Mar 2026 11:54:03 +0700 Subject: [PATCH 14/32] Add MoS + SOTA technique stack for competitive testing Mixture of Softmax (K=2) output layer integrated with full SOTA technique stack: 11L Int6 + XSA4 + Partial RoPE + LN Scale + Tight SWA + VE128 + U-Net skips + Late QAT + SmearGate + BigramHash + FA3. - train_gpt_mos_sota.py: MoS class, FA3 soft fallback, nll_loss branch - run_mos_sota.sh: MODE=baseline|mos|smoke, auto FA3 selective build Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 158 ++++ train_gpt_mos_sota.py | 1750 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1908 insertions(+) create mode 100755 run_mos_sota.sh create mode 100644 train_gpt_mos_sota.py diff --git a/run_mos_sota.sh b/run_mos_sota.sh new file mode 100755 index 0000000000..9a20cefc80 --- /dev/null +++ b/run_mos_sota.sh @@ -0,0 +1,158 @@ +#!/usr/bin/env bash +# === Parameter Golf: MoS + SOTA Techniques on 1x/8x H100 (RunPod) === +# Tests Mixture of Softmax (K=2) with full SOTA technique stack. 
+#
+# Usage on RunPod:
+#   git clone https://github.com/User123331/runpod-testing.git
+#   cd runpod-testing
+#   bash run_mos_sota.sh
+#
+# Modes:
+#   MODE=baseline bash run_mos_sota.sh   # SOTA stack without MoS (control)
+#   MODE=mos bash run_mos_sota.sh        # SOTA stack + MoS K=2 (experiment)
+#   MODE=smoke bash run_mos_sota.sh      # Quick 300s smoke test with MoS
+
+set -euo pipefail
+
+log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; }
+
+MODE="${MODE:-mos}"
+SEED="${SEED:-1337}"
+MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}"
+HF_TOKEN="${HF_TOKEN:-${HUGGINGFACE_TOKEN:-}}"  # never hard-code tokens; pass via env
+
+case "${MODE}" in
+  baseline)
+    USE_MOS=0
+    MOS_K=2
+    BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-2048}"
+    RUN_TAG="sota_baseline"
+    ;;
+  mos)
+    USE_MOS=1
+    MOS_K="${MOS_K:-2}"
+    BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-1024}"  # reduced to fit MoS in 16MB
+    RUN_TAG="sota_mos_k${MOS_K}"
+    ;;
+  smoke)
+    USE_MOS=1
+    MOS_K="${MOS_K:-2}"
+    BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-1024}"
+    MAX_WALLCLOCK_SECONDS=300
+    RUN_TAG="sota_mos_smoke"
+    ;;
+  *)
+    echo "Unknown MODE=${MODE}. Use: baseline, mos, smoke" >&2
+    exit 1
+    ;;
+esac
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+TRAIN_SCRIPT="${SCRIPT_DIR}/train_gpt_mos_sota.py"
+RUN_ID="${RUN_TAG}_seed${SEED}_$(date +%Y%m%d_%H%M%S)"
+LOG_DIR="${LOG_DIR:-${SCRIPT_DIR}/logs}"
+mkdir -p "${LOG_DIR}"
+LOG_PATH="${LOG_DIR}/${RUN_ID}.log"
+
+[ -f "${TRAIN_SCRIPT}" ] || { echo "ERROR: ${TRAIN_SCRIPT} not found"; exit 1; }
+
+# Ensure deps
+python3 -c "import huggingface_hub, zstandard" 2>/dev/null || \
+  pip install --quiet huggingface_hub zstandard
+
+# Build FA3 (selective, ~5 min) if not already installed
+if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then
+  log "FA3 not found. Building selectively (~5 min)..."
+ if [ -d "/workspace/flash-attention" ]; then + FA3_DIR="/workspace/flash-attention" + else + git clone https://github.com/Dao-AILab/flash-attention.git /workspace/flash-attention + FA3_DIR="/workspace/flash-attention" + fi + cd "${FA3_DIR}/hopper" + # Only build bf16 hdim64 SM90 causal — skip everything else + export FLASH_ATTENTION_DISABLE_FP16=TRUE + export FLASH_ATTENTION_DISABLE_FP8=TRUE + export FLASH_ATTENTION_DISABLE_HDIM96=TRUE + export FLASH_ATTENTION_DISABLE_HDIM128=TRUE + export FLASH_ATTENTION_DISABLE_HDIM192=TRUE + export FLASH_ATTENTION_DISABLE_HDIM256=TRUE + export FLASH_ATTENTION_DISABLE_SM80=TRUE + export FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE + export FLASH_ATTENTION_DISABLE_APPENDKV=TRUE + export FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE + export FLASH_ATTENTION_DISABLE_PACKGQA=TRUE + export FLASH_ATTENTION_DISABLE_VARLEN=TRUE + export FLASH_ATTENTION_DISABLE_SPLIT=TRUE + export FLASH_ATTENTION_DISABLE_LOCAL=TRUE + export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE + pip install -e . + cd "${SCRIPT_DIR}" + log "FA3 build complete." +else + log "FA3 already installed." +fi + +# Download dataset if needed +DATA_DIR="data/datasets/fineweb10B_sp1024" +TOK_PATH="data/tokenizers/fineweb_1024_bpe.model" +if [ ! -f "${DATA_DIR}/fineweb_train_000000.bin" ] || [ ! -f "${TOK_PATH}" ]; then + log "Downloading FineWeb dataset..." + if [ -n "${HF_TOKEN}" ]; then export HF_TOKEN; fi + python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 +fi + +GPU_COUNT="$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ')" +log "Detected ${GPU_COUNT} GPU(s). 
Mode: ${MODE}" +log "MoS: USE_MOS=${USE_MOS} MOS_K=${MOS_K} BIGRAM_VOCAB_SIZE=${BIGRAM_VOCAB_SIZE}" +log "Run ID: ${RUN_ID}" + +export PYTHONUNBUFFERED=1 +export RUN_ID +export DATA_PATH="./${DATA_DIR}" +export TOKENIZER_PATH="./${TOK_PATH}" +export VOCAB_SIZE=1024 +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3.0 +export TIE_EMBEDDINGS=1 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 +export BIGRAM_VOCAB_SIZE +export BIGRAM_DIM=128 +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_MOMENTUM_WARMUP_START=0.92 +export MUON_MOMENTUM_WARMUP_STEPS=1500 +export WARMDOWN_ITERS=3000 +export ITERATIONS=9000 +export MAX_WALLCLOCK_SECONDS +export EVAL_STRIDE=64 +export SWA_ENABLED=1 +export SWA_EVERY=50 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export XSA_LAST_N=4 +export ROPE_DIMS=16 +export LN_SCALE=1 +export LATE_QAT_THRESHOLD=0.1 +export VE_ENABLED=1 +export VE_DIM=128 +export VE_LAYERS="9,10" +export USE_MOS +export MOS_K +export SEED + +log "Starting training..." +torchrun --standalone --nproc_per_node="${GPU_COUNT}" "${TRAIN_SCRIPT}" 2>&1 | tee "${LOG_PATH}" + +log "Run completed. Key metrics:" +grep -E 'val_bpb|model_params|mos_params|final_int|submission|Serialized|artifact|swa:' "${LOG_PATH}" | tail -20 || true + +log "Done. Log: ${LOG_PATH}" diff --git a/train_gpt_mos_sota.py b/train_gpt_mos_sota.py new file mode 100644 index 0000000000..d229422102 --- /dev/null +++ b/train_gpt_mos_sota.py @@ -0,0 +1,1750 @@ +""" +MoS + SOTA technique stack for Parameter Golf. +Mixture of Softmax (K=2) output layer with 11L Int6 + XSA4 + Partial RoPE + LN Scale + +Tight SWA + Shared VE128 + U-Net skips + Late QAT + SmearGate + BigramHash + FA3. 
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+try:
+    import zstandard
+    _COMPRESSOR = "zstd"
+except ImportError:
+    _COMPRESSOR = "zlib"
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch._dynamo
+torch._dynamo.config.optimize_ddp = False  # Required for FA3 + torch.compile backward pass
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    from flash_attn_interface import flash_attn_func as flash_attn_3_func
+    _FA3_AVAILABLE = True
+except ImportError:
+    _FA3_AVAILABLE = False
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default run (all overridable via env vars):
+#   - 11 transformer blocks at width 512
+#   - 8 attention heads with 4 KV heads (GQA) and 3x MLP expansion
+#   - vocab size 1024, sequence length 2048, tied embeddings
+#   - 786,432 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500))
+
+    # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # Efficient partial XSA: apply to last N layers only (deep layers have highest self-attention bias) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.1)) + + # Value 
# Value Embeddings: 1 shared table, per-layer scales (saves 50% VE params)
    ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1")))
    ve_dim = int(os.environ.get("VE_DIM", 128))
    ve_layers = os.environ.get("VE_LAYERS", "9,10")

    # Mixture of Softmax output layer
    use_mos = bool(int(os.environ.get("USE_MOS", "0")))
    mos_k = int(os.environ.get("MOS_K", 2))
    # NOTE(review): run_pilot.sh and test_mos.py reference a MOS_RANK env var
    # (low-rank MoS projection), but it is never read here and the in-file
    # MixtureOfSoftmax has no rank option — confirm whether low-rank MoS
    # was meant to be wired into train_gpt.py.

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize G via a quintic Newton-Schulz iteration.

    Runs in bfloat16 for speed; the result is only approximately
    semi-orthogonal, which is sufficient for Muon's update direction.
    `eps` guards the initial normalization against a zero-norm gradient.
    """
    # Quintic iteration coefficients (modded-nanogpt values).
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    # Iterate on the "wide" orientation so A = X @ X.T is the small Gram matrix.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz orthogonalized update (Muon), with optional
    decoupled weight decay. Work is sharded across ranks: rank r computes the
    update for every (i % world_size == r)-th param, then an all_reduce SUM
    reassembles the full flat update buffer on every rank."""

    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
                 nesterov: bool = True, weight_decay: float = 0.0):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
                 nesterov=nesterov, weight_decay=weight_decay),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Degrade gracefully to single-process behavior when torch.distributed
        # is not initialized (world_size=1, rank=0 → every param is "ours").
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # One flat bf16 buffer holding the concatenated updates for every
            # param in the group; ranks fill disjoint slices, the rest stay 0.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        # Nesterov lookahead: grad + momentum * buffer.
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Shape-dependent scale: tall matrices get a sqrt(rows/cols) boost.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Offset advances for EVERY param so slices line up across ranks.
                curr += p.numel()

            if distributed:
                # SUM of disjoint slices == concatenation of all rank-local updates.
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                if wd > 0.0:
                    # Decoupled weight decay, applied before the Muon step.
                    p.data.mul_(1.0 - lr * wd)
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used for byte accounting in BPB evaluation.

    Returns three tensors indexed by token id:
      - base_bytes: UTF-8 byte length of the piece (leading '▁' stripped; byte tokens = 1)
      - has_leading_space: piece started with the SentencePiece '▁' marker
      - is_boundary_token: control/unknown/unused tokens (treated as word boundaries)
    Tables are sized max(sp_vocab, model vocab) so either id space can index them.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern` and trim to a
    whole number of seq_len sequences (+1 token so (x, y) can be built by shift)."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    eval_seq_len: int | None = None,
) -> tuple[float, float]:
    """Non-overlapping validation pass. Returns (mean token NLL, bits-per-byte).

    Sequences are split contiguously across ranks; per-rank sums of loss,
    token count, and byte count are all-reduced before the final metric.
    BPB = (loss / ln 2) * tokens_per_byte, tokenizer-agnostic via the LUTs.
    """
    seq_len = eval_seq_len or args.train_seq_len
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}"
        )
    local_batch_seqs = local_batch_tokens // seq_len
    total_seqs = (val_tokens.numel() - 1) // seq_len
    # Contiguous, near-equal shard of sequences for this rank.
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # float64 accumulators to avoid drift over many batches.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * seq_len
            raw_end = batch_seq_end * seq_len + 1  # +1 for the shifted target
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, seq_len)
            y = local[1:].reshape(-1, seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            # model returns mean loss; re-weight by token count to sum correctly.
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            # Byte count of each target, plus 1 for a leading space unless the
            # previous token was a boundary (control/unknown/unused) token.
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
# Name substrings of small "control" tensors (gates/scales/mixes) that must
# survive quantization at full fp32 precision; overridable via env var.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
# Float tensors at or below this element count are stored as-is (fp16) rather
# than quantized — metadata would dominate any savings.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Clip at this percentile of |w| before rounding, to tame outlier weights.
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Raw storage size of t in bytes (numel * element_size)."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Decide the stored form of a small float tensor kept out of int8.

    Control tensors stay fp32; other fp32/bf16 tensors are downcast to fp16
    and their original dtype is recorded in `passthrough_orig_dtypes`
    (mutated in place) so dequantization can restore it exactly.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Quantize a float tensor to int8, returning (q, scale).

    2-D tensors get one scale per row (clipped at INT8_CLIP_Q per row);
    everything else gets a single per-tensor scale. Dequantize as q * scale.
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            # NOTE(review): torch.empty leaves this uninitialized for a 2-D
            # tensor with zero columns — torch.zeros seems intended; confirm.
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    """Quantize a state dict for export. Returns (payload_obj, stats).

    Payload format "int8_clean_per_row_v1" — do not change field names without
    updating dequantize_state_dict_int8 and any stored artifacts.
    """
    # Single supported clean-script export format:
    # - per-row int8 for 2D float tensors
    # - per-tensor int8 for other float tensors
    # - exact passthrough for non-floats
    # - passthrough for small float tensors, stored as fp16 to save bytes
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            # Non-scalar scale ⇒ per-row scheme; record for dequantization.
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Inverse of quantize_state_dict_int8: rebuild a float state dict."""
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        # One shared TokenStream per rank; each rank reads the same files but
        # slices out its own disjoint span in next_batch.
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return (x, y) of shape (local_tokens // seq_len, seq_len) for this rank.

        Consumes per_rank_span * world_size tokens from the stream on every
        rank (streams stay in lockstep) and keeps only this rank's span.
        """
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1  # +1 so targets can be built by shifting
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    """Linear layer that casts its weight (and bias) to the input dtype.

    When the class-level QAT flag is set, training forward passes use a
    straight-through estimator against a per-row int6 (-32..31) fake-quantized
    weight, so gradients flow through the full-precision weight.
    """

    # Class-wide toggle flipped by the training loop (see QAT_ENABLED).
    _qat_enabled: bool = False

    def forward(self, x: Tensor) -> Tensor:
        w = self.weight.to(x.dtype)
        if CastedLinear._qat_enabled and self.training and w.ndim == 2:
            with torch.no_grad():
                w32 = self.weight.float()
                row_max = w32.abs().amax(dim=1)
                scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
                w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)
            # Straight-through: forward sees w_q, backward sees w.
            w = w + (w_q - w).detach()
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len.
    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.train_seq_len = train_seq_len
        # rope_dims <= 0 means full-dimension RoPE.
        self.rope_dims = rope_dims if rope_dims > 0 else dim
        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # cos/sin cache, keyed on (seq_len, device); rebuilt lazily in forward.
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        """Return (cos, sin) shaped [1, seq_len, 1, rope_dims // 2] in `dtype`."""
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            rd = self.rope_dims
            if seq_len > self.train_seq_len:
                # NTK extrapolation: stretch the base so low frequencies cover
                # the longer context without retraining.
                scale = seq_len / self.train_seq_len
                new_base = self.base * (scale ** (rd / (rd - 2)))
                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
            else:
                inv_freq = self.inv_freq.to(device)
            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = freqs.cos()[None, :, None, :]
            self._sin_cached = freqs.sin()[None, :, None, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
    """Apply (optionally partial) rotary embedding to the last dimension of x."""
    if rope_dims > 0 and rope_dims < x.size(-1):
        # Partial RoPE: rotate only the first rope_dims channels, pass the rest through.
        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
        half = rope_dims // 2
        x1, x2 = x_rope[..., :half], x_rope[..., half:]
        # Half-split rotation convention (not interleaved pairs).
        x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
        return torch.cat((x_rope, x_pass), dim=-1)
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """GQA causal self-attention with QK-norm, per-head query gain, partial
    RoPE, optional value embeddings, and optional XSA self-value removal.
    Uses FlashAttention-3 when available, otherwise SDPA with GQA."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts as identity
        # Learned per-head query gain (fp32 control parameter).
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rope_dims = 0  # set by GPT.__init__ for partial RoPE
        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024)
        self.use_xsa = False  # set by GPT.__init__ for deep layers only

    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
        """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave).
        y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv."""
        B, T, H, D = y.shape
        Hkv = v.size(-2)
        group = H // Hkv
        # Reshape y into KV head groups — free view, no memory alloc
        y_g = y.reshape(B, T, Hkv, group, D)  # [B, T, Hkv, group, D]
        vn = F.normalize(v, dim=-1).unsqueeze(-2)  # [B, T, Hkv, 1, D] — broadcast ready
        # Project out self-value component per KV head group
        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
        return (y_g - proj).reshape(B, T, H, D)

    def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        v = self.c_v(x)
        # Value embedding: add token identity directly to values before reshape
        if v_embed is not None:
            v = v + v_embed
        v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        # QK-norm: RMS-normalize q and k per head before RoPE.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
        if _FA3_AVAILABLE:
            y = flash_attn_3_func(q, k, v, causal=True)
        else:
            # SDPA fallback: FA3 uses [B, T, H, D], SDPA uses [B, H, T, D]
            q_t = q.transpose(1, 2)
            k_t = k.transpose(1, 2)
            v_t = v.transpose(1, 2)
            y = F.scaled_dot_product_attention(
                q_t, k_t, v_t, attn_mask=None, is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            ).transpose(1, 2).contiguous()
        if self.use_xsa:
            y = self._xsa_efficient(y, v)
        y = y.reshape(bsz, seqlen, dim)
        return self.proj(y)


class SmearGate(nn.Module):
    """Learned per-channel blend of each position with its predecessor
    (position 0 blends with zeros). Gate starts at sigmoid(0) = 0.5."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # Shift right by one along the sequence axis.
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hashed bigram embedding added to the token embedding.

    Consecutive token pairs are hashed into a small table; position 0 maps to
    a reserved id (mod). Output is zero at init (zero-init table/proj) and
    scaled by a learned fp32 scalar.
    """

    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Hash (prev, cur) token pairs into [0, bigram_vocab_size - 2];
        the first position gets the reserved id `mod`."""
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod
        # Multiplicative xor hash; int32 wraparound is fine, `%` keeps it non-negative.
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class ValueEmbedding(nn.Module):
    """Reinject token identity into attention values at specific layers.
    Each table maps vocab tokens to a low-dim embedding, projected to model_dim."""
    def __init__(self, vocab_size: int, ve_dim: int, model_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, ve_dim)
        nn.init.normal_(self.embed.weight, std=0.01)
        self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)  # contribution starts at zero
        self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(token_ids)
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class MixtureOfSoftmax(nn.Module):
    """MoS output layer: K separate softmax distributions mixed by learned gate.
    Returns log-probabilities, not logits. Use F.nll_loss, not F.cross_entropy."""
    # NOTE(review): the sibling test_mos.py implements an additional low-rank
    # (proj_down/proj_up) variant selected by MOS_RANK; this class has no such
    # option — confirm whether low-rank MoS was meant to land here too.

    def __init__(self, model_dim: int, n_mixtures: int = 2):
        super().__init__()
        self.n_mixtures = n_mixtures
        self.model_dim = model_dim
        # Per-mixture context projections, fused into one matmul.
        self.projections = CastedLinear(model_dim, n_mixtures * model_dim, bias=False)
        self.gate = CastedLinear(model_dim, n_mixtures, bias=False)

    def forward(self, hidden: Tensor, weight_matrix: Tensor) -> Tensor:
        # hidden: [B*T, D], weight_matrix: [V, D]
        K = self.n_mixtures
        D = self.model_dim
        pi = F.softmax(self.gate(hidden), dim=-1)  # [B*T, K]
        projected = self.projections(hidden).view(-1, K, D)  # [B*T, K, D]
        logits = projected @ weight_matrix.T  # [B*T, K, V]
        log_probs = F.log_softmax(logits, dim=-1)  # [B*T, K, V]
        # NOTE(review): log(softmax(...) + 1e-10) — F.log_softmax on the gate
        # logits would be more numerically direct; confirm before changing.
        log_pi = torch.log(pi.unsqueeze(-1) + 1e-10)  # [B*T, K, 1]
        # Mixture in log space: log sum_k pi_k * p_k(v).
        return torch.logsumexp(log_probs + log_pi, dim=1)  # [B*T, V]


class MLP(nn.Module):
    """Feed-forward block with relu^2 (squared ReLU) activation."""

    def __init__(self, dim: int, mlp_mult: int):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts as identity

    def forward(self, x: Tensor) -> Tensor:
        x = torch.relu(self.fc(x))
        return self.proj(x.square())


class Block(nn.Module):
    """Transformer block: x0-mixing, pre-norm attention + MLP with learned
    per-channel residual scales, optional depth-scaled norms and an optional
    detached transition gate (DTG) over the whole block update."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        layer_idx: int = 0,
        ln_scale: bool = False,
        dtg: bool = False,
    ):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        # Per-channel residual-branch scales (fp32 control params).
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights the running stream, resid_mix[1] the layer-0 input.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        # Depth-scaled normalization factor (1/sqrt(layer+1)) when enabled.
        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
        if dtg:
            # Gate initialized to sigmoid(2) ≈ 0.88, i.e. mostly-open at start.
            self.dtg_gate = nn.Linear(dim, 1, bias=True)
            nn.init.zeros_(self.dtg_gate.weight)
            nn.init.constant_(self.dtg_gate.bias, 2.0)
        else:
            self.dtg_gate = None

    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed)
        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor)
        if self.dtg_gate is not None:
            # Gate input is detached so the gate does not backprop into the stream.
            gate = torch.sigmoid(self.dtg_gate(x_in.detach()))
            x_out = x_in + gate * (x_out - x_in)
        return x_out


class GPT(nn.Module):
    """U-net style GPT: encoder half collects skips, decoder half consumes them.
    Optional bigram hash embedding, smear gate, shared value embeddings,
    multi-token-prediction heads, XSA on the deepest layers, and a
    Mixture-of-Softmax output head (which bypasses the logit softcap)."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        mtp_num_heads: int = 0,
        mtp_loss_weight: float = 0.1,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        xsa_last_n: int = 0,
        rope_dims: int = 0,
        ln_scale: bool = False,
        dtg: bool = False,
        ve_enabled: bool = False,
        ve_dim: int = 128,
        ve_layers: str = "9,10",
        use_mos: bool = False,
        mos_k: int = 2,
    ):
        super().__init__()
        self._ve_target_dim = num_kv_heads * (model_dim // num_heads)  # kv_dim for value projection
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.mtp_num_heads = mtp_num_heads
        self.mtp_loss_weight = mtp_loss_weight
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.smear = SmearGate(model_dim)
        # U-net split: first half encodes (pushes skips), second half decodes (pops them).
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.blocks = nn.ModuleList(
            [
                Block(
                    model_dim,
                    num_heads,
                    num_kv_heads,
                    mlp_mult,
                    rope_base,
                    qk_gain_init,
                    layer_idx=i,
                    ln_scale=ln_scale,
                    dtg=dtg,
                )
                for i in range(num_layers)
            ]
        )
        # Set partial RoPE on all attention blocks
        if rope_dims > 0:
            head_dim = model_dim // num_heads
            for block in self.blocks:
                block.attn.rope_dims = rope_dims
                block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims)
        # Value embeddings: 1 SHARED table, per-layer learned scales
        self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else []
        kv_dim = self._ve_target_dim
        if self.ve_layer_indices:
            self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim)
            self.ve_layer_scales = nn.ParameterList(
                [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]
            )
        else:
            self.ve_shared = None
            self.ve_layer_scales = nn.ParameterList()
        self.value_embeds = nn.ModuleList()  # keep empty for compat
        self.final_norm = RMSNorm()
        # Tied embeddings reuse tok_emb.weight as the output projection.
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        # Multi-token-prediction heads: head k predicts target at offset k+2.
        self.mtp_heads = nn.ModuleList(
            [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)]
        )
        for head in self.mtp_heads:
            head._zero_init = True
        # Enable efficient XSA on the deepest layers (highest self-attention bias)
        if xsa_last_n > 0:
            for i in range(max(0, num_layers - xsa_last_n), num_layers):
                self.blocks[i].attn.use_xsa = True
        # Mixture of Softmax output layer
        self.use_mos = use_mos
        # Flag read by eval code: MoS forward_logits returns log-probs (use NLL).
        self.returns_log_probs = use_mos
        self.mos = MixtureOfSoftmax(model_dim, n_mixtures=mos_k) if use_mos else None
        self._init_weights()

    def _init_weights(self) -> None:
        """Orthogonal init for large matrices, zero init for flagged residual
        projections, 1/sqrt(2L)-scaled output projections, tied-embed normal init."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self.blocks)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    if ".proj." in name or name.endswith(".proj"):
                        # Scale down residual-output projections (GPT-2 style).
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None:
        """Get value embedding for a specific layer using shared table + per-layer scale."""
        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
            return None
        # Cache the shared VE computation (same for all layers, different scale)
        if ve_cache is not None and 've' not in ve_cache:
            ve_cache['ve'] = self.ve_shared(input_ids)
        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
        ve_idx = self.ve_layer_indices.index(layer_idx)
        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return the scalar mean training loss for (input_ids, target_ids).

        MoS path: mos() yields log-probs → NLL loss, no logit softcap.
        Standard path: (tied or untied) projection → tanh softcap → cross-entropy.
        MTP heads add an auxiliary softcapped CE term during training only.
        """
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # layer-0 stream, re-mixed into every block via resid_mix
        skips: list[Tensor] = []
        ve_cache: dict = {}

        for i in range(self.num_encoder_layers):
            ve = self._get_ve(i, input_ids, ve_cache)
            x = self.blocks[i](x, x0, v_embed=ve)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            bi = self.num_encoder_layers + i
            if skips:
                # LIFO skip connections with learned per-channel weights.
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            ve = self._get_ve(bi, input_ids, ve_cache)
            x = self.blocks[bi](x, x0, v_embed=ve)

        x = self.final_norm(x)
        x_flat = x.reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.mos is not None:
            # MoS returns log-probs already — NLL loss, and softcap is skipped.
            log_probs = self.mos(x_flat, self.tok_emb.weight)
            main_loss = F.nll_loss(log_probs.float(), targets, reduction="mean")
        else:
            if self.tie_embeddings:
                logits_proj = F.linear(x_flat, self.tok_emb.weight)
            else:
                if self.lm_head is None:
                    raise RuntimeError("lm_head is required when tie_embeddings=False")
                logits_proj = self.lm_head(x_flat)
            logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
            main_loss = F.cross_entropy(logits.float(), targets, reduction="mean")

        if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0:
            _, seqlen, dim = x.shape
            mtp_loss_sum = x.new_zeros(())
            mtp_loss_count = 0
            for k, mtp_head in enumerate(self.mtp_heads):
                # Head k predicts target_ids shifted by k+1 extra positions.
                valid_t = seqlen - (k + 1)
                if valid_t <= 0:
                    continue
                mtp_hidden = x[:, :valid_t, :].reshape(-1, dim)
                mtp_targets = target_ids[:, k + 1 :].reshape(-1)
                mtp_logits_proj = mtp_head(mtp_hidden)
                mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap)
                mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean")
                mtp_loss_count += 1
            if mtp_loss_count > 0:
                main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count)

        return main_loss

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Return logits or log-probs (bsz, seq_len, vocab) without computing loss."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        ve_cache: dict = {}
        for i in range(self.num_encoder_layers):
            ve = self._get_ve(i, input_ids, ve_cache)
            x = self.blocks[i](x, x0, v_embed=ve)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            bi = self.num_encoder_layers + i
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            ve = self._get_ve(bi, input_ids, ve_cache)
            x = self.blocks[bi](x, x0, v_embed=ve)
        x = self.final_norm(x)
        if self.mos is not None:
            # MoS path: log-probabilities, no softcap (mirrors forward()).
            bsz, seqlen, dim = x.shape
            x_flat = x.reshape(-1, dim)
            log_probs = self.mos(x_flat, self.tok_emb.weight)
            return log_probs.reshape(bsz, seqlen, -1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


# -----------------------------
# SLIDING WINDOW EVALUATION
# -----------------------------

def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
    eval_seq_len: int | None = None,
) -> tuple[float, float]:
    """Sliding window evaluation: each token scored with maximum context."""
    seq_len = eval_seq_len or args.train_seq_len
    total_tokens = val_tokens.numel() - 1

    # One window per stride step; windows shorter than 1 token are dropped.
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= 1]
    total_windows = len(window_starts)

    # Contiguous shard of windows for this rank.
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
    # MoS models return log-probs from forward_logits — score with NLL, not CE.
    _use_nll = getattr(base_model, 'returns_log_probs', False)

    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)

            # Fixed-shape zero-padded batch so torch.compile sees a static shape.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []

            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = compiled_logits(x_batch)

            output_flat = logits.reshape(-1, logits.size(-1)).float()
            targets_flat = y_batch.reshape(-1)
            if _use_nll:
                nll = F.nll_loss(output_flat, targets_flat, reduction="none").reshape(bsz, seq_len)
            else:
                nll = F.cross_entropy(output_flat, targets_flat, reduction="none").reshape(bsz, seq_len)

            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Only score the final `stride` tokens of each window (full
                # context), except the first window which is scored entirely.
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte


# -----------------------------
# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts)
# -----------------------------

def _classify_param(name: str) -> str:
    """Bucket a state-dict key into a quantization category."""
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if "mos." in name:
        return "mos"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]:
    """Quantize to int6 range (-32..31) stored in int8: per-row scales for
    matrices, a single per-tensor scale otherwise. Returns (q, fp16 scale)."""
    t32 = t.float()
    if t32.ndim == 2:
        row_max = t32.abs().amax(dim=1)
        scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16)
        q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8)
        return q, scale
    amax = t32.abs().max().item()
    scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8)
    return q, scale

def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
    """Quantize large float tensors: int6 for categories in `int6_cats`, int8
    otherwise; small (<= 65536 elems) and control tensors pass through.
    Returns (flat result dict with ".q"/".scale" suffixed keys, meta dict)."""
    num_layers_total = max(
        (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
        default=0,
    ) + 1
    # NOTE(review): late_k_layers is computed but not used in the visible body.
    late_k_layers = set(range(num_layers_total - 2, num_layers_total))

    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        if not t.is_floating_point() or t.numel() <= 65536:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        # tok_emb.weight falls through to int8 via "embed" category
        if cat in int6_cats and t.ndim >= 1:
            q, s = quantize_int6_per_row(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int6"}
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta

def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta.get(name)
        if info is
# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    """Entry point: distributed/CUDA setup, warmup, training loop, SWA,
    int6 export, and roundtrip + sliding-window evaluation."""
    global zeropower_via_newtonschulz5

    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    # Global batch is fixed at 8 micro-batches, split across ranks.
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    # Fast math knobs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp

    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Rank-0-only logging; always appended to the logfile, optionally echoed.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    # The full source is logged so each run is self-describing.
    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # -----------------------------
    # TOKENIZER + VALIDATION METRIC SETUP
    # -----------------------------

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len
    val_seq_len = max(args.train_seq_len, effective_eval_seq_len)
    val_tokens = load_validation_tokens(args.val_files, val_seq_len)
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # -----------------------------
    # MODEL + OPTIMIZER SETUP
    # -----------------------------

    CastedLinear._qat_enabled = args.qat_enabled

    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        mtp_num_heads=args.mtp_num_heads,
        mtp_loss_weight=args.mtp_loss_weight,
        bigram_vocab_size=args.bigram_vocab_size,
        bigram_dim=args.bigram_dim,
        xsa_last_n=args.xsa_last_n,
        rope_dims=args.rope_dims,
        ln_scale=args.ln_scale,
        dtg=args.dtg_enabled,
        ve_enabled=args.ve_enabled,
        ve_dim=args.ve_dim,
        ve_layers=args.ve_layers,
        use_mos=args.use_mos,
        mos_k=args.mos_k,
    ).to(device).bfloat16()
    # CastedLinear weights stay fp32 (they cast at forward time); low-dim params too.
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Optimizer split:
    #  - token embedding (Adam) uses EMBED_LR
    #  - untied lm_head (Adam) uses HEAD_LR
    #  - matrix params in transformer blocks use MATRIX_LR via Muon
    #  - vectors/scalars use SCALAR_LR via Adam
    block_named_params = list(base_model.blocks.named_parameters())
    matrix_params = [
        p
        for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if base_model.mtp_num_heads > 0:
        matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2])
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    scalar_params.append(base_model.smear.gate)
    # gnp_scale removed in v22
    if base_model.bigram is not None:
        scalar_params.append(base_model.bigram.scale)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
    if base_model.bigram is not None:
        tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr})
        if base_model.bigram.proj is not None:
            matrix_params.append(base_model.bigram.proj.weight)
    if base_model.ve_shared is not None:
        tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr})
        if base_model.ve_shared.proj is not None:
            matrix_params.append(base_model.ve_shared.proj.weight)
        scalar_params.append(base_model.ve_shared.scale)
        for s in base_model.ve_layer_scales:
            scalar_params.append(s)
    # MoS optimizer assignment
    # NOTE(review): this assumes the full-rank MoS (`projections` attribute).
    # A low-rank MoS variant (proj_down/proj_up, MOS_RANK>0 as described in the
    # commit message) would raise AttributeError here — confirm which variant
    # GPT instantiates; mos_rank is also not passed to the GPT constructor above.
    if base_model.mos is not None:
        matrix_params.append(base_model.mos.projections.weight)  # 512x1024, 2D → Muon
        scalar_params.append(base_model.mos.gate.weight)  # 512x2, too narrow for NS5 → AdamW
    optimizer_tok = torch.optim.AdamW(
        tok_params,
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        weight_decay=args.adam_wd,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
        weight_decay=args.muon_wd,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.AdamW(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        weight_decay=args.adam_wd,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters())
    mos_params = sum(p.numel() for p in base_model.mos.parameters()) if base_model.mos is not None else 0
    log0(f"model_params:{n_params}")
    if mos_params > 0:
        log0(f"mos_params:{mos_params} mos_k:{args.mos_k}")
    log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}")
    xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa]
    log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")

    # -----------------------------
    # DATA LOADER & MODEL WARMUP
    # -----------------------------

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier in [0, 1]: flat until warmdown, then linear decay. With a
        # wallclock cap the warmdown window is estimated from observed step time.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
    # initial weights/optimizer state so measured training starts from the true init.
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only sync grads on the last micro-step.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        # Fresh loader so measured training replays the data from the start.
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # -----------------------------
    # MAIN TRAINING LOOP
    # -----------------------------

    swa_state: dict[str, Tensor] | None = None
    swa_count = 0

    training_time_ms = 0.0
    stop_after_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Validation time is excluded from the measured training time.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        # Optionally switch on quantization-aware training late in the schedule.
        if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled:
            CastedLinear._qat_enabled = True
            log0(f"late_qat:enabled step:{step} scale:{scale:.4f}")
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Muon momentum linearly warms up from warmup_start to its final value.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)

        # Stochastic weight averaging of late checkpoints (once LR has decayed below 0.2x).
        if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0:
            if swa_state is None:
                swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
                swa_count = 1
                log0(f"swa:start step:{step}")
            else:
                for name, t in base_model.state_dict().items():
                    swa_state[name] += t.detach().cpu()
                swa_count += 1

        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )

        # Needed to sync whether we've reached the wallclock cap.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    if args.swa_enabled and swa_state is not None and swa_count > 1:
        log0(f"swa:applying averaged {swa_count} checkpoints")
        avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype)
                     for name, t in swa_state.items()}
        base_model.load_state_dict(avg_state, strict=True)

    # Diagnostic eval: measure quality after SWA, before quantization
    torch.cuda.synchronize()
    t_diag = time.perf_counter()
    diag_val_loss, diag_val_bpb = eval_val(
        args, compiled_model, rank, world_size, device, grad_accum_steps,
        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
    )
    torch.cuda.synchronize()
    log0(
        f"DIAGNOSTIC post_swa val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms"
    )

    # -----------------------------
    # SERIALIZATION + ROUNDTRIP VALIDATION
    # -----------------------------

    full_state_dict = base_model.state_dict()
    # MTP heads are train-time only; drop them from the exported artifact.
    export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k}
    excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k)
    if excluded_mtp > 0:
        log0(f"export_excluding_mtp_params:{excluded_mtp}")

    if master_process:
        torch.save(export_sd, "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")

    sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()}
    quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "mos"})
    quant_buf = io.BytesIO()
    torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9)
    if master_process:
        with open("final_model.int6.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = len(quant_blob)
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes")
        log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes")

    # Roundtrip: decompress + dequantize into fresh model + eval
    if distributed:
        dist.barrier()
    with open("final_model.int6.ptz", "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(
        io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)),
        map_location="cpu",
    )
    deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)

    eval_model = GPT(
        vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
        num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
        mtp_num_heads=0, mtp_loss_weight=0.0,
        bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
        xsa_last_n=args.xsa_last_n,  # must match training model
        rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled,
        ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers,
        use_mos=args.use_mos, mos_k=args.mos_k,
    ).to(device).bfloat16()
    for m in eval_model.modules():
        if isinstance(m, CastedLinear):
            m.float()
    restore_low_dim_params_to_fp32(eval_model)
    eval_model.load_state_dict(deq_state, strict=True)
    compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True)

    # Standard non-overlapping eval (sanity check)
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args, compiled_eval, rank, world_size, device, grad_accum_steps,
        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
        eval_seq_len=effective_eval_seq_len,
    )
    torch.cuda.synchronize()
    log0(
        f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
    )
    log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

    # Sliding window eval (submission score)
    sw_seq_len = effective_eval_seq_len
    if args.eval_stride > 0 and args.eval_stride < sw_seq_len:
        torch.cuda.synchronize()
        t_slide = time.perf_counter()
        sw_val_loss, sw_val_bpb = eval_val_sliding(
            args, eval_model, rank, world_size, device,
            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
            stride=args.eval_stride,
            eval_seq_len=sw_seq_len,
        )
        torch.cuda.synchronize()
        log0(
            f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} "
            f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms"
        )
        log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}")

    # Second sliding window eval at stride=64 for submission comparison
    if args.eval_stride != 64 and 64 < sw_seq_len:
        torch.cuda.synchronize()
        t_slide64 = time.perf_counter()
        sw64_val_loss, sw64_val_bpb = eval_val_sliding(
            args, eval_model, rank, world_size, device,
            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
            stride=64,
            eval_seq_len=sw_seq_len,
        )
        torch.cuda.synchronize()
        log0(
f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From 5a0fa0be65f5e5980ffdca43ef2d3368542f0fbd Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sun, 22 Mar 2026 11:55:19 +0700 Subject: [PATCH 15/32] Run training with nohup to survive terminal disconnects Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index 9a20cefc80..e5e5cba799 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -150,9 +150,15 @@ export MOS_K export SEED log "Starting training..." -torchrun --standalone --nproc_per_node="${GPU_COUNT}" "${TRAIN_SCRIPT}" 2>&1 | tee "${LOG_PATH}" +log "Log file: ${LOG_PATH}" +nohup torchrun --standalone --nproc_per_node="${GPU_COUNT}" "${TRAIN_SCRIPT}" > "${LOG_PATH}" 2>&1 & +TRAIN_PID=$! +log "Training launched in background (PID: ${TRAIN_PID}). Safe to disconnect." +log "Monitor with: tail -f ${LOG_PATH}" +wait ${TRAIN_PID} +TRAIN_EXIT=$? -log "Run completed. Key metrics:" +log "Training finished (exit code: ${TRAIN_EXIT}). Key metrics:" grep -E 'val_bpb|model_params|mos_params|final_int|submission|Serialized|artifact|swa:' "${LOG_PATH}" | tail -20 || true log "Done. 
Log: ${LOG_PATH}" From a5a63919154aa34e8172a681a318df8685183d5c Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sun, 22 Mar 2026 11:59:03 +0700 Subject: [PATCH 16/32] Fix FA3 build: use --no-build-isolation so setup.py can find torch Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index e5e5cba799..1de12782cc 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -88,7 +88,7 @@ if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE export FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE - pip install -e . + pip install --no-build-isolation -e . cd "${SCRIPT_DIR}" log "FA3 build complete." else From 73ba4f71734fdf0cc6899a8a10056a8a0ef17c55 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sun, 22 Mar 2026 12:09:42 +0700 Subject: [PATCH 17/32] Add keep-alive heartbeat to prevent RunPod pod termination Pings nvidia-smi every 60s in background to keep pod active during FA3 build and other CPU-only phases. Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index 1de12782cc..f2d3ab1f5b 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -16,6 +16,11 @@ set -euo pipefail log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } +# Keep-alive heartbeat: prevents RunPod from killing pod during long builds +(while true; do sleep 60; nvidia-smi > /dev/null 2>&1; done) & +KEEPALIVE_PID=$! 
+trap "kill ${KEEPALIVE_PID} 2>/dev/null" EXIT + MODE="${MODE:-mos}" SEED="${SEED:-1337}" MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" From 34519d8ab538801739991e55d3e07e036225edb0 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sun, 22 Mar 2026 12:31:01 +0700 Subject: [PATCH 18/32] Fix FA3 build: clear stale build dir, fix variable scoping Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index f2d3ab1f5b..ab08c3e781 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -68,13 +68,12 @@ python3 -c "import huggingface_hub, zstandard" 2>/dev/null || \ # Build FA3 (selective, ~5 min) if not already installed if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then log "FA3 not found. Building selectively (~5 min)..." - if [ -d "/workspace/flash-attention" ]; then - FA3_DIR="/workspace/flash-attention" - else - git clone https://github.com/Dao-AILab/flash-attention.git /workspace/flash-attention - FA3_DIR="/workspace/flash-attention" + FA3_DIR="/workspace/flash-attention" + if [ ! -d "${FA3_DIR}" ]; then + git clone https://github.com/Dao-AILab/flash-attention.git "${FA3_DIR}" fi cd "${FA3_DIR}/hopper" + rm -rf build/ # clear any stale full-build artifacts # Only build bf16 hdim64 SM90 causal — skip everything else export FLASH_ATTENTION_DISABLE_FP16=TRUE export FLASH_ATTENTION_DISABLE_FP8=TRUE From b2e9c10288bb266a2ce0e4332f8fad0ffea975bd Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sun, 22 Mar 2026 12:35:37 +0700 Subject: [PATCH 19/32] Add sentencepiece and numpy to deps check train_gpt_mos_sota.py imports sentencepiece as spm at the top level; without it the script exits immediately on import. numpy is also used directly. Both are now checked and installed before training starts. 
Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index ab08c3e781..e9b3a76dbd 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -62,8 +62,8 @@ LOG_PATH="${LOG_DIR}/${RUN_ID}.log" [ -f "${TRAIN_SCRIPT}" ] || { echo "ERROR: ${TRAIN_SCRIPT} not found"; exit 1; } # Ensure deps -python3 -c "import huggingface_hub, zstandard" 2>/dev/null || \ - pip install --quiet huggingface_hub zstandard +python3 -c "import huggingface_hub, zstandard, sentencepiece, numpy" 2>/dev/null || \ + pip install --quiet huggingface_hub zstandard sentencepiece numpy # Build FA3 (selective, ~5 min) if not already installed if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then From d7aa8c4dc5455bb50ef7a754a156f63f5daa4154 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Sun, 22 Mar 2026 12:43:37 +0700 Subject: [PATCH 20/32] Fix FA3 install: mkdir flash_attn_3 before pip editable install pip copies the compiled .so into flash_attn_3/ relative to the hopper dir, but that subdir doesn't exist after a fresh clone. All kernels compiled successfully; only the final copy step was failing. Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index e9b3a76dbd..b5ebc8bdf4 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -74,6 +74,7 @@ if ! 
python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; fi cd "${FA3_DIR}/hopper" rm -rf build/ # clear any stale full-build artifacts + mkdir -p flash_attn_3 # pip copies .so here; dir must exist # Only build bf16 hdim64 SM90 causal — skip everything else export FLASH_ATTENTION_DISABLE_FP16=TRUE export FLASH_ATTENTION_DISABLE_FP8=TRUE From 6003c5550c4e2641c90f300593cbb84b56fba834 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 19:40:12 +0700 Subject: [PATCH 21/32] Add hyperbolic.ai setup scripts for 8x H100 Co-Authored-By: Claude Opus 4.6 --- quickstart_hyperbolic.sh | 41 ++++++++++++++++ setup_hyperbolic.sh | 100 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 quickstart_hyperbolic.sh create mode 100644 setup_hyperbolic.sh diff --git a/quickstart_hyperbolic.sh b/quickstart_hyperbolic.sh new file mode 100644 index 0000000000..d48d2763fc --- /dev/null +++ b/quickstart_hyperbolic.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +# === Hyperbolic.ai Quick Start (Paste Into SSH) === +# Paste this entire block after SSHing into your 8x H100 instance + +set -euo pipefail +log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } +(while true; do sleep 60; nvidia-smi > /dev/null 2>&1; done) & +trap "kill $! 2>/dev/null" EXIT + +GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ') +log "Detected ${GPU_COUNT} GPUs" + +# Clone repos +cd /workspace +[ ! -d "parameter-golf" ] && git clone https://github.com/openai/parameter-golf.git +[ ! -d "runpod-testing" ] && git clone https://github.com/User123331/runpod-testing.git + +# Build FA3 selectively +if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then + log "Building FA3 (~5 min)..." + [ ! 
-d "flash-attention" ] && git clone https://github.com/Dao-AILab/flash-attention.git + cd /workspace/flash-attention/hopper + rm -rf build/ && mkdir -p flash_attn_3 + export FLASH_ATTENTION_DISABLE_FP16=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE + export FLASH_ATTENTION_DISABLE_HDIM96=TRUE FLASH_ATTENTION_DISABLE_HDIM128=TRUE + export FLASH_ATTENTION_DISABLE_HDIM192=TRUE FLASH_ATTENTION_DISABLE_HDIM256=TRUE + export FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE + export FLASH_ATTENTION_DISABLE_APPENDKV=TRUE FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE + export FLASH_ATTENTION_DISABLE_PACKGQA=TRUE FLASH_ATTENTION_DISABLE_VARLEN=TRUE + export FLASH_ATTENTION_DISABLE_SPLIT=TRUE FLASH_ATTENTION_DISABLE_LOCAL=TRUE + export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE + pip install --no-build-isolation -e . +fi + +# Download dataset +cd /workspace/parameter-golf +log "Downloading dataset (8B tokens)..." +python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 + +log "Setup complete! Ready for MoS experiments." 
\ No newline at end of file diff --git a/setup_hyperbolic.sh b/setup_hyperbolic.sh new file mode 100644 index 0000000000..532e277ff1 --- /dev/null +++ b/setup_hyperbolic.sh @@ -0,0 +1,100 @@ +#!/usr/bin/env bash +# === Hyperbolic.ai 8x H100 Setup Script === +# Run this after SSHing into your instance +# +# Usage: +# wget https://raw.githubusercontent.com/User123331/runpod-testing/main/setup_hyperbolic.sh +# chmod +x setup_hyperbolic.sh +# ./setup_hyperbolic.sh +# +# Or paste directly from clipboard + +set -euo pipefail + +log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } + +# Check GPU count +GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ') +log "Detected ${GPU_COUNT} GPU(s)" + +if [ "${GPU_COUNT}" -lt 8 ]; then + log "WARNING: Expected 8 GPUs, found ${GPU_COUNT}" +fi + +# Keep-alive to prevent timeout during long builds +(while true; do sleep 60; nvidia-smi > /dev/null 2>&1; done) & +KEEPALIVE_PID=$! +trap "kill ${KEEPALIVE_PID} 2>/dev/null" EXIT + +# 1. Clone the competition repo (already in image, but verify) +if [ ! -d "/workspace/parameter-golf" ]; then + log "Cloning parameter-golf repo..." + cd /workspace + git clone https://github.com/openai/parameter-golf.git + cd parameter-golf +else + log "parameter-golf repo already exists" + cd /workspace/parameter-golf +fi + +# 2. Clone our MoS-enhanced training scripts +log "Cloning runpod-testing repo with MoS implementation..." +if [ ! -d "/workspace/runpod-testing" ]; then + cd /workspace + git clone https://github.com/User123331/runpod-testing.git +else + cd /workspace/runpod-testing + git pull || true +fi + +# 3. Build Flash Attention 3 (selective, ~5 min) +log "Building Flash Attention 3 (selective kernels only)..." +if python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then + log "FA3 already installed" +else + FA3_DIR="/workspace/flash-attention" + if [ ! 
-d "${FA3_DIR}" ]; then + git clone https://github.com/Dao-AILab/flash-attention.git "${FA3_DIR}" + fi + cd "${FA3_DIR}/hopper" + rm -rf build/ + mkdir -p flash_attn_3 + + # Only build bf16 hdim64 SM90 causal — skip everything else + export FLASH_ATTENTION_DISABLE_FP16=TRUE + export FLASH_ATTENTION_DISABLE_FP8=TRUE + export FLASH_ATTENTION_DISABLE_HDIM96=TRUE + export FLASH_ATTENTION_DISABLE_HDIM128=TRUE + export FLASH_ATTENTION_DISABLE_HDIM192=TRUE + export FLASH_ATTENTION_DISABLE_HDIM256=TRUE + export FLASH_ATTENTION_DISABLE_SM80=TRUE + export FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE + export FLASH_ATTENTION_DISABLE_APPENDKV=TRUE + export FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE + export FLASH_ATTENTION_DISABLE_PACKGQA=TRUE + export FLASH_ATTENTION_DISABLE_VARLEN=TRUE + export FLASH_ATTENTION_DISABLE_SPLIT=TRUE + export FLASH_ATTENTION_DISABLE_LOCAL=TRUE + export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE + + log "Starting FA3 selective build (~5 min)..." + pip install --no-build-isolation -e . + log "FA3 build complete" +fi + +# 4. Download dataset (80 train shards = 8B tokens) +cd /workspace/parameter-golf +log "Downloading FineWeb dataset (8B tokens)..." +HF_TOKEN="${HF_TOKEN:-${HUGGINGFACE_TOKEN:-}}" python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 + +# 5. Quick sanity check +log "" +log "=== Setup Complete ===" +log "GPU Count: ${GPU_COUNT}" +log "FA3 Status: $(python3 -c 'from flash_attn_interface import flash_attn_func; print("OK")' 2>/dev/null || echo 'FAILED')" +log "Dataset: $(ls -1 data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) train shards" +log "" +log "Ready to run experiments!" 
+log "See: /workspace/runpod-testing/run_mos_sota.sh" \ No newline at end of file From b2f4d3e46c2ffb5ab3f9f524a39ad95f0f067fdc Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 19:42:34 +0700 Subject: [PATCH 22/32] Update quickstart to use pre-compiled FA3 .so Co-Authored-By: Claude Opus 4.6 --- quickstart_hyperbolic.sh | 52 +++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/quickstart_hyperbolic.sh b/quickstart_hyperbolic.sh index d48d2763fc..c4ae265476 100644 --- a/quickstart_hyperbolic.sh +++ b/quickstart_hyperbolic.sh @@ -1,6 +1,7 @@ #!/usr/bin/env bash # === Hyperbolic.ai Quick Start (Paste Into SSH) === # Paste this entire block after SSHing into your 8x H100 instance +# Uses PRE-COMPILED FA3 .so to skip the 5-min kernel compilation! set -euo pipefail log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } @@ -15,22 +16,38 @@ cd /workspace [ ! -d "parameter-golf" ] && git clone https://github.com/openai/parameter-golf.git [ ! -d "runpod-testing" ] && git clone https://github.com/User123331/runpod-testing.git -# Build FA3 selectively +# Install FA3 using pre-compiled .so + cloned Python interface if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then - log "Building FA3 (~5 min)..." + log "Installing FA3 (pre-compiled .so + Python interface)..." + + # Clone flash-attention repo for Python interface files [ ! 
-d "flash-attention" ] && git clone https://github.com/Dao-AILab/flash-attention.git + + # Copy pre-compiled .so into place + cd /workspace/runpod-testing/"compiled FA3" + SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") + mkdir -p "${SITE_PACKAGES}/flash_attn_3" + cp _C.abi3.so "${SITE_PACKAGES}/flash_attn_3/" + cp flash_attn_config.py "${SITE_PACKAGES}/flash_attn_3/" + + # Copy Python interface from flash-attention/hopper/flash_attn_3 cd /workspace/flash-attention/hopper - rm -rf build/ && mkdir -p flash_attn_3 - export FLASH_ATTENTION_DISABLE_FP16=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE - export FLASH_ATTENTION_DISABLE_HDIM96=TRUE FLASH_ATTENTION_DISABLE_HDIM128=TRUE - export FLASH_ATTENTION_DISABLE_HDIM192=TRUE FLASH_ATTENTION_DISABLE_HDIM256=TRUE - export FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE - export FLASH_ATTENTION_DISABLE_APPENDKV=TRUE FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE - export FLASH_ATTENTION_DISABLE_PACKGQA=TRUE FLASH_ATTENTION_DISABLE_VARLEN=TRUE - export FLASH_ATTENTION_DISABLE_SPLIT=TRUE FLASH_ATTENTION_DISABLE_LOCAL=TRUE - export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE - export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE - pip install --no-build-isolation -e . + cp -r flash_attn_3/*.py "${SITE_PACKAGES}/flash_attn_3/" 2>/dev/null || true + + # Install the interface package + pip install -e . 
--no-build-isolation 2>/dev/null || { + # If that fails, just copy the interface files manually + cp flash_attn_interface.py "${SITE_PACKAGES}/" 2>/dev/null || true + } + + # Symlink flash_attn_config.py to torch path (fixes torch.compile backward crash) + TORCH_PATH=$(python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))") + ln -sf "${SITE_PACKAGES}/flash_attn_3/flash_attn_config.py" "${TORCH_PATH}/flash_attn_config.py" 2>/dev/null || true + + log "FA3 installed" + python3 -c "from flash_attn_interface import flash_attn_func; print('FA3: OK')" || { + log "WARNING: FA3 interface check failed, may need to build from source" + } fi # Download dataset @@ -38,4 +55,11 @@ cd /workspace/parameter-golf log "Downloading dataset (8B tokens)..." python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 -log "Setup complete! Ready for MoS experiments." \ No newline at end of file +log "" +log "=== Setup Complete ===" +log "GPUs: ${GPU_COUNT}" +log "FA3: $(python3 -c 'from flash_attn_interface import flash_attn_func; print("OK")' 2>/dev/null || echo 'FAILED - will need to build')" +log "Dataset: $(ls -1 data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) train shards" +log "" +log "Ready! Run experiments with:" +log " cd /workspace/runpod-testing && MODE=mos bash run_mos_sota.sh" \ No newline at end of file From b34ba1e66e2ef6a47308394d54682c573b2ff858 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 19:43:55 +0700 Subject: [PATCH 23/32] Fix data paths for hyperbolic setup Co-Authored-By: Claude Opus 4.6 --- quickstart_hyperbolic.sh | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/quickstart_hyperbolic.sh b/quickstart_hyperbolic.sh index c4ae265476..8671d9b45f 100644 --- a/quickstart_hyperbolic.sh +++ b/quickstart_hyperbolic.sh @@ -36,7 +36,6 @@ if ! 
python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; # Install the interface package pip install -e . --no-build-isolation 2>/dev/null || { - # If that fails, just copy the interface files manually cp flash_attn_interface.py "${SITE_PACKAGES}/" 2>/dev/null || true } @@ -46,20 +45,31 @@ if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; log "FA3 installed" python3 -c "from flash_attn_interface import flash_attn_func; print('FA3: OK')" || { - log "WARNING: FA3 interface check failed, may need to build from source" + log "WARNING: FA3 interface check failed, will need selective build" } fi -# Download dataset +# Download dataset to runpod-testing/data (where run_mos_sota.sh expects it) +cd /workspace/runpod-testing +mkdir -p data/datasets data/tokenizers + +log "Downloading FineWeb dataset (8B tokens)..." cd /workspace/parameter-golf -log "Downloading dataset (8B tokens)..." python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 +# Symlink data from parameter-golf to runpod-testing +cd /workspace/runpod-testing +[ ! -L "data/datasets/fineweb10B_sp1024" ] && \ + ln -s /workspace/parameter-golf/data/datasets/fineweb10B_sp1024 data/datasets/ +[ ! -L "data/tokenizers/fineweb_1024_bpe.model" ] && \ + ln -s /workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model data/tokenizers/ + log "" log "=== Setup Complete ===" log "GPUs: ${GPU_COUNT}" -log "FA3: $(python3 -c 'from flash_attn_interface import flash_attn_func; print("OK")' 2>/dev/null || echo 'FAILED - will need to build')" +log "FA3: $(python3 -c 'from flash_attn_interface import flash_attn_func; print("OK")' 2>/dev/null || echo 'FAILED')" log "Dataset: $(ls -1 data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) train shards" log "" log "Ready! 
Run experiments with:" -log " cd /workspace/runpod-testing && MODE=mos bash run_mos_sota.sh" \ No newline at end of file +log " cd /workspace/runpod-testing" +log " MODE=mos bash run_mos_sota.sh" \ No newline at end of file From 9ba2ec41f2a74f67ab9a1c52ab827ce17d5fb0ef Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 20:13:37 +0700 Subject: [PATCH 24/32] Fix: use ~/golf instead of /workspace for hyperbolic Co-Authored-By: Claude Opus 4.6 --- quickstart_hyperbolic.sh | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/quickstart_hyperbolic.sh b/quickstart_hyperbolic.sh index 8671d9b45f..bfd461a551 100644 --- a/quickstart_hyperbolic.sh +++ b/quickstart_hyperbolic.sh @@ -1,7 +1,6 @@ #!/usr/bin/env bash -# === Hyperbolic.ai Quick Start (Paste Into SSH) === -# Paste this entire block after SSHing into your 8x H100 instance -# Uses PRE-COMPILED FA3 .so to skip the 5-min kernel compilation! +# === Hyperbolic.ai Quick Start === +# Run from home directory set -euo pipefail log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } @@ -11,8 +10,11 @@ trap "kill $! 2>/dev/null" EXIT GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ') log "Detected ${GPU_COUNT} GPUs" +WORK_DIR="$HOME/golf" +mkdir -p "${WORK_DIR}" +cd "${WORK_DIR}" + # Clone repos -cd /workspace [ ! -d "parameter-golf" ] && git clone https://github.com/openai/parameter-golf.git [ ! -d "runpod-testing" ] && git clone https://github.com/User123331/runpod-testing.git @@ -24,14 +26,14 @@ if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; [ ! 
-d "flash-attention" ] && git clone https://github.com/Dao-AILab/flash-attention.git # Copy pre-compiled .so into place - cd /workspace/runpod-testing/"compiled FA3" + cd "${WORK_DIR}/runpod-testing/compiled FA3" SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") mkdir -p "${SITE_PACKAGES}/flash_attn_3" cp _C.abi3.so "${SITE_PACKAGES}/flash_attn_3/" cp flash_attn_config.py "${SITE_PACKAGES}/flash_attn_3/" # Copy Python interface from flash-attention/hopper/flash_attn_3 - cd /workspace/flash-attention/hopper + cd "${WORK_DIR}/flash-attention/hopper" cp -r flash_attn_3/*.py "${SITE_PACKAGES}/flash_attn_3/" 2>/dev/null || true # Install the interface package @@ -49,27 +51,26 @@ if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; } fi -# Download dataset to runpod-testing/data (where run_mos_sota.sh expects it) -cd /workspace/runpod-testing -mkdir -p data/datasets data/tokenizers - +# Download dataset +cd "${WORK_DIR}/parameter-golf" log "Downloading FineWeb dataset (8B tokens)..." -cd /workspace/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 # Symlink data from parameter-golf to runpod-testing -cd /workspace/runpod-testing +cd "${WORK_DIR}/runpod-testing" +mkdir -p data/datasets data/tokenizers [ ! -L "data/datasets/fineweb10B_sp1024" ] && \ - ln -s /workspace/parameter-golf/data/datasets/fineweb10B_sp1024 data/datasets/ + ln -s "${WORK_DIR}/parameter-golf/data/datasets/fineweb10B_sp1024" data/datasets/ [ ! 
-L "data/tokenizers/fineweb_1024_bpe.model" ] && \ - ln -s /workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model data/tokenizers/ + ln -s "${WORK_DIR}/parameter-golf/data/tokenizers/fineweb_1024_bpe.model" data/tokenizers/ log "" log "=== Setup Complete ===" +log "Work dir: ${WORK_DIR}" log "GPUs: ${GPU_COUNT}" log "FA3: $(python3 -c 'from flash_attn_interface import flash_attn_func; print("OK")' 2>/dev/null || echo 'FAILED')" log "Dataset: $(ls -1 data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) train shards" log "" log "Ready! Run experiments with:" -log " cd /workspace/runpod-testing" +log " cd ${WORK_DIR}/runpod-testing" log " MODE=mos bash run_mos_sota.sh" \ No newline at end of file From a77f8a7183203a1f00f5585ceb8bff87e956048d Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 20:14:32 +0700 Subject: [PATCH 25/32] Fix: use $HOME instead of /workspace for FA3 build Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index b5ebc8bdf4..60c045cbdc 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -68,7 +68,7 @@ python3 -c "import huggingface_hub, zstandard, sentencepiece, numpy" 2>/dev/null # Build FA3 (selective, ~5 min) if not already installed if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then log "FA3 not found. Building selectively (~5 min)..." - FA3_DIR="/workspace/flash-attention" + FA3_DIR="${HOME}/flash-attention" if [ ! 
-d "${FA3_DIR}" ]; then git clone https://github.com/Dao-AILab/flash-attention.git "${FA3_DIR}" fi From 057a8444177499788d2269d513db2dc289678bb3 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 20:17:32 +0700 Subject: [PATCH 26/32] Add --break-system-packages for externally-managed environments Co-Authored-By: Claude Opus 4.6 --- quickstart_hyperbolic.sh | 2 +- run_mos_sota.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/quickstart_hyperbolic.sh b/quickstart_hyperbolic.sh index bfd461a551..712a67aba2 100644 --- a/quickstart_hyperbolic.sh +++ b/quickstart_hyperbolic.sh @@ -37,7 +37,7 @@ if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; cp -r flash_attn_3/*.py "${SITE_PACKAGES}/flash_attn_3/" 2>/dev/null || true # Install the interface package - pip install -e . --no-build-isolation 2>/dev/null || { + pip install -e . --no-build-isolation --break-system-packages 2>/dev/null || { cp flash_attn_interface.py "${SITE_PACKAGES}/" 2>/dev/null || true } diff --git a/run_mos_sota.sh b/run_mos_sota.sh index 60c045cbdc..7b9aa54a8c 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -63,7 +63,7 @@ LOG_PATH="${LOG_DIR}/${RUN_ID}.log" # Ensure deps python3 -c "import huggingface_hub, zstandard, sentencepiece, numpy" 2>/dev/null || \ - pip install --quiet huggingface_hub zstandard sentencepiece numpy + pip install --quiet huggingface_hub zstandard sentencepiece numpy --break-system-packages # Build FA3 (selective, ~5 min) if not already installed if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then @@ -93,7 +93,7 @@ if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE export FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE - pip install --no-build-isolation -e . + pip install --no-build-isolation --break-system-packages -e . 
cd "${SCRIPT_DIR}" log "FA3 build complete." else From c88f9a2ca61c98da013b580ac10a1ac737c97426 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 20:19:30 +0700 Subject: [PATCH 27/32] Fix: use SCRIPT_DIR instead of hardcoded golf path Co-Authored-By: Claude Opus 4.6 --- quickstart_hyperbolic.sh | 64 +++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/quickstart_hyperbolic.sh b/quickstart_hyperbolic.sh index 712a67aba2..2ab3c9a036 100644 --- a/quickstart_hyperbolic.sh +++ b/quickstart_hyperbolic.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # === Hyperbolic.ai Quick Start === -# Run from home directory +# Run from ~/runpod-testing directory set -euo pipefail log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } @@ -10,67 +10,65 @@ trap "kill $! 2>/dev/null" EXIT GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ') log "Detected ${GPU_COUNT} GPUs" -WORK_DIR="$HOME/golf" -mkdir -p "${WORK_DIR}" -cd "${WORK_DIR}" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "${SCRIPT_DIR}" -# Clone repos -[ ! -d "parameter-golf" ] && git clone https://github.com/openai/parameter-golf.git -[ ! -d "runpod-testing" ] && git clone https://github.com/User123331/runpod-testing.git +# Clone parameter-golf if needed +if [ ! -d "$HOME/parameter-golf" ]; then + log "Cloning parameter-golf..." + git clone https://github.com/openai/parameter-golf.git "$HOME/parameter-golf" +fi -# Install FA3 using pre-compiled .so + cloned Python interface +# Install FA3 using pre-compiled .so if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then - log "Installing FA3 (pre-compiled .so + Python interface)..." + log "Installing FA3 from pre-compiled .so..." - # Clone flash-attention repo for Python interface files - [ ! -d "flash-attention" ] && git clone https://github.com/Dao-AILab/flash-attention.git + # Clone flash-attention for Python interface + if [ ! 
-d "$HOME/flash-attention" ]; then + git clone https://github.com/Dao-AILab/flash-attention.git "$HOME/flash-attention" + fi - # Copy pre-compiled .so into place - cd "${WORK_DIR}/runpod-testing/compiled FA3" SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") + + # Copy pre-compiled .so mkdir -p "${SITE_PACKAGES}/flash_attn_3" - cp _C.abi3.so "${SITE_PACKAGES}/flash_attn_3/" - cp flash_attn_config.py "${SITE_PACKAGES}/flash_attn_3/" + cp "${SCRIPT_DIR}/compiled FA3/_C.abi3.so" "${SITE_PACKAGES}/flash_attn_3/" + cp "${SCRIPT_DIR}/compiled FA3/flash_attn_config.py" "${SITE_PACKAGES}/flash_attn_3/" - # Copy Python interface from flash-attention/hopper/flash_attn_3 - cd "${WORK_DIR}/flash-attention/hopper" - cp -r flash_attn_3/*.py "${SITE_PACKAGES}/flash_attn_3/" 2>/dev/null || true + # Copy Python interface + cp "$HOME/flash-attention/hopper/flash_attn_3/"*.py "${SITE_PACKAGES}/flash_attn_3/" 2>/dev/null || true - # Install the interface package - pip install -e . --no-build-isolation --break-system-packages 2>/dev/null || { - cp flash_attn_interface.py "${SITE_PACKAGES}/" 2>/dev/null || true - } + # Install interface + cd "$HOME/flash-attention/hopper" + pip install -e . --no-build-isolation --break-system-packages 2>/dev/null || true - # Symlink flash_attn_config.py to torch path (fixes torch.compile backward crash) + # Symlink config to torch (fixes torch.compile backward crash) TORCH_PATH=$(python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))") ln -sf "${SITE_PACKAGES}/flash_attn_3/flash_attn_config.py" "${TORCH_PATH}/flash_attn_config.py" 2>/dev/null || true - log "FA3 installed" python3 -c "from flash_attn_interface import flash_attn_func; print('FA3: OK')" || { - log "WARNING: FA3 interface check failed, will need selective build" + log "WARNING: FA3 check failed" } fi # Download dataset -cd "${WORK_DIR}/parameter-golf" +cd "$HOME/parameter-golf" log "Downloading FineWeb dataset (8B tokens)..." 
python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 -# Symlink data from parameter-golf to runpod-testing -cd "${WORK_DIR}/runpod-testing" +# Symlink data to runpod-testing +cd "${SCRIPT_DIR}" mkdir -p data/datasets data/tokenizers [ ! -L "data/datasets/fineweb10B_sp1024" ] && \ - ln -s "${WORK_DIR}/parameter-golf/data/datasets/fineweb10B_sp1024" data/datasets/ + ln -s "$HOME/parameter-golf/data/datasets/fineweb10B_sp1024" data/datasets/ [ ! -L "data/tokenizers/fineweb_1024_bpe.model" ] && \ - ln -s "${WORK_DIR}/parameter-golf/data/tokenizers/fineweb_1024_bpe.model" data/tokenizers/ + ln -s "$HOME/parameter-golf/data/tokenizers/fineweb_1024_bpe.model" data/tokenizers/ log "" log "=== Setup Complete ===" -log "Work dir: ${WORK_DIR}" log "GPUs: ${GPU_COUNT}" log "FA3: $(python3 -c 'from flash_attn_interface import flash_attn_func; print("OK")' 2>/dev/null || echo 'FAILED')" log "Dataset: $(ls -1 data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) train shards" log "" -log "Ready! Run experiments with:" -log " cd ${WORK_DIR}/runpod-testing" +log "Ready! Run:" log " MODE=mos bash run_mos_sota.sh" \ No newline at end of file From bd347d9b8369e7cdb620d0dabc7e32b4592e2f2d Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 20:22:19 +0700 Subject: [PATCH 28/32] Fix: use user site-packages instead of system Co-Authored-By: Claude Opus 4.6 --- quickstart_hyperbolic.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/quickstart_hyperbolic.sh b/quickstart_hyperbolic.sh index 2ab3c9a036..7e11de7ca1 100644 --- a/quickstart_hyperbolic.sh +++ b/quickstart_hyperbolic.sh @@ -28,7 +28,9 @@ if ! 
python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; git clone https://github.com/Dao-AILab/flash-attention.git "$HOME/flash-attention" fi - SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") + # Use user site-packages (writable without sudo) + SITE_PACKAGES=$(python3 -c "import site; print(site.getusersitepackages())") + mkdir -p "${SITE_PACKAGES}" # Copy pre-compiled .so mkdir -p "${SITE_PACKAGES}/flash_attn_3" From 61a9b21d86ecf0ca55a09d9aa7507d3ff8b0a2f5 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 20:23:32 +0700 Subject: [PATCH 29/32] Build FA3 from source (pre-compiled .so not in repo) Co-Authored-By: Claude Opus 4.6 --- quickstart_hyperbolic.sh | 45 +++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/quickstart_hyperbolic.sh b/quickstart_hyperbolic.sh index 7e11de7ca1..941f7cf6f7 100644 --- a/quickstart_hyperbolic.sh +++ b/quickstart_hyperbolic.sh @@ -19,38 +19,45 @@ if [ ! -d "$HOME/parameter-golf" ]; then git clone https://github.com/openai/parameter-golf.git "$HOME/parameter-golf" fi -# Install FA3 using pre-compiled .so +# Build FA3 selectively (~5 min on H100) if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then - log "Installing FA3 from pre-compiled .so..." + log "Building Flash Attention 3 (selective, ~5 min)..." - # Clone flash-attention for Python interface if [ ! 
-d "$HOME/flash-attention" ]; then git clone https://github.com/Dao-AILab/flash-attention.git "$HOME/flash-attention" fi - # Use user site-packages (writable without sudo) - SITE_PACKAGES=$(python3 -c "import site; print(site.getusersitepackages())") - mkdir -p "${SITE_PACKAGES}" + cd "$HOME/flash-attention/hopper" + rm -rf build/ + mkdir -p flash_attn_3 - # Copy pre-compiled .so - mkdir -p "${SITE_PACKAGES}/flash_attn_3" - cp "${SCRIPT_DIR}/compiled FA3/_C.abi3.so" "${SITE_PACKAGES}/flash_attn_3/" - cp "${SCRIPT_DIR}/compiled FA3/flash_attn_config.py" "${SITE_PACKAGES}/flash_attn_3/" + # Only build bf16 hdim64 SM90 - skip everything else + export FLASH_ATTENTION_DISABLE_FP16=TRUE + export FLASH_ATTENTION_DISABLE_FP8=TRUE + export FLASH_ATTENTION_DISABLE_HDIM96=TRUE + export FLASH_ATTENTION_DISABLE_HDIM128=TRUE + export FLASH_ATTENTION_DISABLE_HDIM192=TRUE + export FLASH_ATTENTION_DISABLE_HDIM256=TRUE + export FLASH_ATTENTION_DISABLE_SM80=TRUE + export FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE + export FLASH_ATTENTION_DISABLE_APPENDKV=TRUE + export FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE + export FLASH_ATTENTION_DISABLE_PACKGQA=TRUE + export FLASH_ATTENTION_DISABLE_VARLEN=TRUE + export FLASH_ATTENTION_DISABLE_SPLIT=TRUE + export FLASH_ATTENTION_DISABLE_LOCAL=TRUE + export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE - # Copy Python interface - cp "$HOME/flash-attention/hopper/flash_attn_3/"*.py "${SITE_PACKAGES}/flash_attn_3/" 2>/dev/null || true - - # Install interface - cd "$HOME/flash-attention/hopper" - pip install -e . --no-build-isolation --break-system-packages 2>/dev/null || true + pip install --no-build-isolation --break-system-packages -e . 
# Symlink config to torch (fixes torch.compile backward crash) + SITE_PACKAGES=$(python3 -c "import site; print(site.getusersitepackages())") TORCH_PATH=$(python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))") ln -sf "${SITE_PACKAGES}/flash_attn_3/flash_attn_config.py" "${TORCH_PATH}/flash_attn_config.py" 2>/dev/null || true - python3 -c "from flash_attn_interface import flash_attn_func; print('FA3: OK')" || { - log "WARNING: FA3 check failed" - } + python3 -c "from flash_attn_interface import flash_attn_func; print('FA3: OK')" fi # Download dataset From 5ebbc36ae30e1632c55257a1b805ac46ea590c91 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 20:41:03 +0700 Subject: [PATCH 30/32] Add DISABLE_COMPILE option to fix torch.compile/inductor issues Default to disabled for stability on fresh environments Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 1 + train_gpt_mos_sota.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index 7b9aa54a8c..dd904f72b3 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -153,6 +153,7 @@ export VE_LAYERS="9,10" export USE_MOS export MOS_K export SEED +export DISABLE_COMPILE="${DISABLE_COMPILE:-1}" # Disable torch.compile by default (fixes inductor issues) log "Starting training..." 
log "Log file: ${LOG_PATH}" diff --git a/train_gpt_mos_sota.py b/train_gpt_mos_sota.py index d229422102..c3df5b8ab2 100644 --- a/train_gpt_mos_sota.py +++ b/train_gpt_mos_sota.py @@ -30,6 +30,10 @@ import torch import torch._dynamo torch._dynamo.config.optimize_ddp = False # Required for FA3 + torch.compile backward pass +# Disable torch.compile if environment variable is set (fixes inductor issues on some systems) +_DISABLE_COMPILE = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) +if _DISABLE_COMPILE: + torch._dynamo.config.suppress_errors = True import torch.distributed as dist import torch.nn.functional as F from torch import Tensor, nn @@ -1075,7 +1079,7 @@ def eval_val_sliding( byte_count = torch.zeros((), device=device, dtype=torch.float64) base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + compiled_logits = base_model.forward_logits if _DISABLE_COMPILE else torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) _use_nll = getattr(base_model, 'returns_log_probs', False) with torch.inference_mode(): @@ -1220,7 +1224,8 @@ def main() -> None: code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + if not _DISABLE_COMPILE: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) # ----------------------------- # DISTRIBUTED + CUDA SETUP @@ -1344,7 +1349,7 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_model = base_model if _DISABLE_COMPILE else torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model # Optimizer split: @@ -1690,7 +1695,7 @@ def lr_mul(step: int, 
elapsed_ms: float) -> float: m.float() restore_low_dim_params_to_fp32(eval_model) eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_eval = eval_model if _DISABLE_COMPILE else torch.compile(eval_model, dynamic=False, fullgraph=True) # Standard non-overlapping eval (sanity check) torch.cuda.synchronize() From 43c0e5a02ce590bae84f7aa85a3c12d5996841e2 Mon Sep 17 00:00:00 2001 From: Billy Endson Date: Mon, 23 Mar 2026 20:50:00 +0700 Subject: [PATCH 31/32] Remove nohup wait - use tmux for persistence instead Co-Authored-By: Claude Opus 4.6 --- run_mos_sota.sh | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/run_mos_sota.sh b/run_mos_sota.sh index dd904f72b3..3c04d53c27 100755 --- a/run_mos_sota.sh +++ b/run_mos_sota.sh @@ -157,12 +157,8 @@ export DISABLE_COMPILE="${DISABLE_COMPILE:-1}" # Disable torch.compile by defau log "Starting training..." log "Log file: ${LOG_PATH}" -nohup torchrun --standalone --nproc_per_node="${GPU_COUNT}" "${TRAIN_SCRIPT}" > "${LOG_PATH}" 2>&1 & -TRAIN_PID=$! -log "Training launched in background (PID: ${TRAIN_PID}). Safe to disconnect." -log "Monitor with: tail -f ${LOG_PATH}" -wait ${TRAIN_PID} -TRAIN_EXIT=$? +torchrun --standalone --nproc_per_node="${GPU_COUNT}" "${TRAIN_SCRIPT}" 2>&1 | tee "${LOG_PATH}" +TRAIN_EXIT=${PIPESTATUS[0]} log "Training finished (exit code: ${TRAIN_EXIT}). 
Key metrics:" grep -E 'val_bpb|model_params|mos_params|final_int|submission|Serialized|artifact|swa:' "${LOG_PATH}" | tail -20 || true From 2e9f535550cff61f9877d35b06c1f73c11cda9dd Mon Sep 17 00:00:00 2001 From: User123331 Date: Fri, 1 May 2026 00:30:57 +0700 Subject: [PATCH 32/32] Add draft parcae px43 embed7 clip1300 run --- .../GATE_REPORT.md | 27 +++ .../PR_DESCRIPTION.md | 50 ++++++ .../parcae-px43-embed7-clip1300/README.md | 57 ++++++ .../requirements.txt | 6 + .../submission.json | 53 ++++++ .../parcae-px43-embed7-clip1300/train_gpt.py | 2 + .../train_seed1337.log | 166 ++++++++++++++++++ .../train_seed2024.log | 166 ++++++++++++++++++ .../train_seed42.log | 166 ++++++++++++++++++ 9 files changed, 693 insertions(+) create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/GATE_REPORT.md create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/PR_DESCRIPTION.md create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/README.md create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/requirements.txt create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/submission.json create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/train_gpt.py create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed1337.log create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed2024.log create mode 100644 records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed42.log diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/GATE_REPORT.md b/records/track_10min_16mb/parcae-px43-embed7-clip1300/GATE_REPORT.md new file mode 100644 index 0000000000..9def917ae1 --- /dev/null +++ b/records/track_10min_16mb/parcae-px43-embed7-clip1300/GATE_REPORT.md @@ -0,0 +1,27 @@ +# Gate Report + +Formal local pre-submit check: + +```text +runs +seed train_ms train_step prequant_bpb sliding_bpb sliding_eval_ms artifact_bytes code_size_bytes 
sentinel +42 600024 4702 1.09071600 1.08802944 89275 15633824 19689 1 +1337 600117 4699 1.09073547 1.08783878 89174 15630505 19689 1 +2024 600093 4702 1.09049411 1.08760994 89318 15630862 19689 1 + +checks +artifact<=16000000 1 +train<=600000ms 0 +train_step>=4000 1 +final_sliding_eval<=600000ms 1 +sentinel_present_all 1 +required_seeds_present 1 +mean_sliding_bpb 1.08782605 +record_track_gate 0 + +verdict FAIL +reason training wallclock gate failed +reason record-track mean gate failed +``` + +The logs should be rerun with a slightly lower training cap if this candidate is intended for a strict 10-minute record-track package. diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/PR_DESCRIPTION.md b/records/track_10min_16mb/parcae-px43-embed7-clip1300/PR_DESCRIPTION.md new file mode 100644 index 0000000000..9942d77c25 --- /dev/null +++ b/records/track_10min_16mb/parcae-px43-embed7-clip1300/PR_DESCRIPTION.md @@ -0,0 +1,50 @@ +# Draft: parcae-px43-embed7-clip1300 + +This is a draft/non-record research submission based on the Parcae loop-injection direction from @mikeapedia's PR #1674: [Non-record: Parcae Loop Injection + Gemma-style Attention + Gram NS](https://github.com/openai/parameter-golf/pull/1674). + +## What This Architecture Is + +The main idea I wanted to test was whether the Parcae-style loop boundary can improve a small recurrent-depth transformer under the 8xH100 / 16MB setting. PR #1674 describes Parcae constrained loop injection as an SSM-inspired boundary condition at loop re-entry points: instead of passing the recurrent hidden state through unchanged, the loop boundary learns a stable decay term and a residual re-injection term from the original stream. In my run, this is combined with the px43/embed7/clip1300 compression setup and evaluated with the legal sliding-window path. 
+ +The submitted package uses: + +- recurrent-depth transformer loop structure over the middle blocks +- QK-gain attention initialization +- skip gates and tied embedding/head path +- EMA post-training weights +- Hessian-aware mixed GPTQ +- 6-bit matrix quantization and 7-bit embedding quantization +- Brotli compression +- final sliding-window evaluation + +## Tokenizer / Data + +This run uses the Mikeapedia SP8192 tokenizer and pretokenized data from: + +- Hugging Face dataset: [Mikeapedia/parameter-golf-sp8192](https://huggingface.co/datasets/Mikeapedia/parameter-golf-sp8192/tree/main/datasets) +- Tokenizer file: `datasets/tokenizers/fineweb_8192_bpe.model` + +The tokenizer SHA256 used by the runner is: + +```text +a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 +``` + +## Results + +Three completed 8xH100 seeds are included: + +| Seed | Sliding BPB | Train Time | Eval Time | Artifact Bytes | +|------|-------------|------------|-----------|----------------| +| 42 | 1.08802944 | 600.024s | 89.275s | 15,633,824 | +| 1337 | 1.08783878 | 600.117s | 89.174s | 15,630,505 | +| 2024 | 1.08760994 | 600.093s | 89.318s | 15,630,862 | +| Mean | 1.08782605 | 600.078s | 89.256s | 15,631,730 | + +The run is not being represented as a valid record. The local gate report is included because the logs exceed the strict 600s training budget by 24-117 ms and the score does not beat the current record threshold. + +## Credits + +Thanks to @mikeapedia for PR #1674 and the Parcae loop-injection research direction, plus the public Mikeapedia SP8192 tokenizer/data bundle used here. PR #1674 also points to its upstream inspirations, including xIELU/per-layer QK-gain work and the Parcae paper lineage; this experiment is an attempt to test that family of ideas under a 3-seed 8xH100 run. + +Thanks also to the Parameter Golf community for the prior work on depth recurrence, QK gain, GPTQ, SP8192 tokenization, and compression/eval tooling that this run builds on. 
diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/README.md b/records/track_10min_16mb/parcae-px43-embed7-clip1300/README.md new file mode 100644 index 0000000000..16d25f6c29 --- /dev/null +++ b/records/track_10min_16mb/parcae-px43-embed7-clip1300/README.md @@ -0,0 +1,57 @@ +# Draft: parcae-px43-embed7-clip1300 + +**Status:** draft only. Do not submit this as a record PR without rerunning. + +**3-seed mean val_bpb:** 1.08782605 + +| Seed | Sliding BPB | Train Time | Eval Time | Artifact Bytes | +|------|-------------|------------|-----------|----------------| +| 42 | 1.08802944 | 600.024s | 89.275s | 15,633,824 | +| 1337 | 1.08783878 | 600.117s | 89.174s | 15,630,505 | +| 2024 | 1.08760994 | 600.093s | 89.318s | 15,630,862 | +| Mean | 1.08782605 | 600.078s | 89.256s | 15,631,730 | + +## Gate Status + +This package is organized in the same shape as a Parameter Golf records-folder submission, but the phase 1 logs do not pass the record gate: + +- artifact size is under 16,000,000 bytes +- all three required seeds completed and emitted `RUN_COMPLETE_DO_NOT_KILL submission_ready=1` +- final sliding eval is under 600 seconds +- training time is over 600 seconds in all three logs by 24-117 ms +- mean BPB does not beat the current leaderboard SOTA + +## How to Run + +```bash +pip install --break-system-packages \ + flash_attn_3 \ + --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280 + +pip install --break-system-packages \ + 'fused-softcap-ce @ git+https://github.com/anthony-maio/fused-softcap-ce.git@25e7ad6292cd1e837eef592f50e4d9f5990b6c84' \ + brotli zstandard sentencepiece numpy tqdm + +DATA_DIR=./data \ +SEED=42 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +Use the Mikeapedia SP8192 data layout: + +```text +data/tokenizers/fineweb_8192_bpe.model +data/datasets/fineweb10B_sp8192/fineweb_train_*.bin +data/datasets/fineweb10B_sp8192/fineweb_val_*.bin +``` + +## Files + +- `train_gpt.py` +- `submission.json` 
+- `PR_DESCRIPTION.md` +- `train_seed42.log` +- `train_seed1337.log` +- `train_seed2024.log` +- `requirements.txt` +- `GATE_REPORT.md` diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/requirements.txt b/records/track_10min_16mb/parcae-px43-embed7-clip1300/requirements.txt new file mode 100644 index 0000000000..3cbdb33a87 --- /dev/null +++ b/records/track_10min_16mb/parcae-px43-embed7-clip1300/requirements.txt @@ -0,0 +1,6 @@ +brotli +zstandard +sentencepiece +numpy +tqdm +fused-softcap-ce @ git+https://github.com/anthony-maio/fused-softcap-ce.git@25e7ad6292cd1e837eef592f50e4d9f5990b6c84 diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/submission.json b/records/track_10min_16mb/parcae-px43-embed7-clip1300/submission.json new file mode 100644 index 0000000000..c875b41364 --- /dev/null +++ b/records/track_10min_16mb/parcae-px43-embed7-clip1300/submission.json @@ -0,0 +1,53 @@ +{ + "name": "Billy Endson", + "github": "User123331", + "val_bpb": 1.08782605, + "artifact_bytes": 15633824, + "training_time_seconds": 600.117, + "gpu": "8xH100 SXM", + "status": "draft_non_record_not_submitted", + "seeds": [42, 1337, 2024], + "runs": [ + { + "seed": 42, + "val_bpb": 1.08802944, + "prequant_bpb": 1.09071600, + "artifact_bytes": 15633824, + "training_time_seconds": 600.024, + "eval_time_seconds": 89.275, + "train_log": "train_seed42.log" + }, + { + "seed": 1337, + "val_bpb": 1.08783878, + "prequant_bpb": 1.09073547, + "artifact_bytes": 15630505, + "training_time_seconds": 600.117, + "eval_time_seconds": 89.174, + "train_log": "train_seed1337.log" + }, + { + "seed": 2024, + "val_bpb": 1.08760994, + "prequant_bpb": 1.09049411, + "artifact_bytes": 15630862, + "training_time_seconds": 600.093, + "eval_time_seconds": 89.318, + "train_log": "train_seed2024.log" + } + ], + "techniques": [ + "SP8192 tokenizer", + "Parallel residual routing", + "Depth recurrence over layers 3-5", + "QK-Gain 5.25", + "MuonEq-style row-normalized optimizer", + "EMA 
post-training weights", + "Hessian-aware mixed GPTQ", + "6-bit matrix quantization", + "7-bit embedding quantization", + "Brotli compression", + "Sliding-window evaluation" + ], + "notes": "Draft package generated from phase1 logs. This is not ready as a record PR because the parsed train logs exceed 600 seconds by 24-117 ms and the mean BPB does not beat the current leaderboard SOTA." +} diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_gpt.py b/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_gpt.py new file mode 100644 index 0000000000..3e27bf641b --- /dev/null +++ b/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_gpt.py @@ -0,0 +1,2 @@ +import lzma as L,base64 as B +exec(L.decompress(B.b85decode(';Qm)V3ta#ahy!|mS*&_qGSJZ_!icLS9#}~(c|-MFLM!d|c*$TJhq5pqrkZ>fc~1t4l>rmZckX-Rd%WS!COFfe&Q)}KwQ^n;h)dD(84o{wm3YxP%X5$NmAl)Zqnj?iBW#g&-k@9qOWq+xvW}T)Sz;_GMaBd){`ib{N1KN|?)FOWS5%@RsZ~=Lo{QUX)Qec|;cg41*J@@8W(!Po7I;qwM+VOdBFE>v;}WAzYyS5kf&7Ks{&^mAZ&W*!25CcOc6-Vh+jKyN52K39NUL&Mo$a+kUUjLY$k9cIe;@vwJkll1E>GaXOO@{05q9SRJzE)}aEt|U1MMvHuO%L3r|(>l3JYEHz#LyIv^cw8arD(7lttj^RIxGI?Gn0F0eQWZU-5m=aZuLQzscdh>(kH;d6cMcvY+!^#)ok2~tkW?Y@>)5G7Jz8Y-C$2cXNiqqpLtO->RPyHZ5MN{e<^_kQA`SeSDTT&GhZ$J4d@Z`h~Z|&g-_P0v3$*4HD4sQm?4&z(+mlj%bCty3kz?th)%Gy_Mk0!IEcT>9tGt5GtpdETQQOkP_&C(jg6%9+Fca^!5X@b?uk#3q!|9&eowqw%q4Rdu5VT?1+o+hP*JZgAUUut4nA=5Ai`gEkbvZ@OBmVK&ooKCb8D|a69ux2H_S+(4QRigRC&6`eGk_R=h4^g>TZe=N!x5fFDR*iWNVPq<@ItC1oAK|eR(^d@K%5%7cD&iOxVcoMgtt0IAd$QwKf>O9+JiznANrkEG~O;3yJaffZc0xfnbqhUVhDSlaHH>rJoA4;Exf(TR)8u@qy!TTCP;pN_u#7S+QVYsDCu-TI63U~1ebRpWIJ*|2lUJwDqXq5?%*DqsK^^>2N(m>@eo_8+Wp;%53lRJd}UIiJpds;QZ3fdwnR-CYL9^Fk?rkkjZE=wVA+#sLDjK+3JS*vyP<<@ct*mULH*Tx-qlWEXblwO&X)@$n{y<&Co80)iQ~TCNU?x{)~i-I=-nX?LprXsj$i`*ph4YJ0M$iVSxVJQj&c(-V9rTF1bnz=GaGfAuzmZE*B&|qQTjis*9B(^o_oIU^ZKr5Attwhh{z}DwwChmTs7N!yJkbu)2-B_-2@?7)h@<%%a;`#g>L&|DO%2uD5((70d`x)~R}qb>=UrefT9UI$4+6;t|Ki#OK-pF-C%Zz6P)mG8mBJrCqcT^4{)dith>s#j?~@jg=5mPFW%?Gp6-@mtW|6Ba8eMU}6~;t}2UTtBkc4H#u`ZRD24kN0*&udSWow1chzl0
v-8yms)c((Bp;^q!^3*KKAyPEyBScpUI$c^WP^&-0|j_0cy5vMBPGyo}JpZLk^c9JdXwqOHM_H!+S_sS)%#^dp9ML4{|c@x%T+7x>cLR($)2zBX6CGK~G*Sb0Y|2(M?0#S|`(+AosuBO(lVB*9;df?z~@yeIAaV2$AeCceMfG_qYyKwpmi)x}3({6r@D=LLNY0-*oG0u%*<4Fptcv?5N+bQgVYoVIHyJSw$tWwSswgY>?V&L~vVD;%#s0B`_S2_g0f>z_w^-TxN6k>%SS~=ehQen#V4{vsP16MyjXjWOzY|*>u&5q{noUlYdbX(`(K1230`)|SL&`&z2W}ny0QU;Lz|MJcN&;$U)slHNTlaH8<5p|29;VQSLSQUCQBaI-yhRA1t5(J&(SeW%F5qy^(>0(@3i_E_3P+ukcvvM7zzjl@9C#s&9;)$mB?6;PIMK%fiVIEH($kgtAXPbk`ZCT{A-2^ZW5ZW?-K*x$W3xy>=LEHMT-&YSf)=rv=2Q^5&2^sQSV>z(fB-mJ%kaj85&AQ|E;VYQ-3wbkJ{mN~!yp9U@IzulkZ#hh4}v#DVx>qNMsgAG=&hCL$KUxFO@V`=;m2M5FjW1Z9uNKCL}Veey1KG*s)2mu|Wr2BSNKV2k9|=8)fo3?HJKHR1kL!@yaklhy~rAF4=h$gF8iG8eRI-Zo!WH>z*r*Z9YQDX_6u=N{lUdyfW@}Q{F3qLz#)f5DtAFJRq3*nOTBA7bb@GYb^I+&((8)-h(CZ=zxtm3=uaUm*qpAq|sS;YGN|?OSJ^#F^StZ14mU*C_qCGKklYYrGP3O<72W>Zpal-W4W;y1$eqLdNqx^U58F0(Pxq*Y9k*ksIXo^*>JkC5;TCTrL46XxA;tB?AmG|R{o=`S|D-{kJOgYe7E9M>^i~KEL4?4jTrbL5NZjW?1ViiKL8nZ2AWtx}hliS-ZiSBqRm>m{usps92va@uj*l!6G_y&j?@dq(3ul`g*C8x~c~lwj$KiPlb;=nxQ?=@LmbF4Wqa!x**-|OIBu@YXTOufO~VE9au2$;^PaQroP%K|F&!QF(HAcyjeKr9~o!`4>y6m&@=z5*CDj>-r`_3Z@AMnKe#TQ>-lo=J?mGMVWS(uf2vBM^(jGn+Psigj37JSvT!~u!yp1;ikS<_QEwq$Ozt`Q=?67&kgX-aXVCH$LP_E_+7fvX0Pv(Fo%GMyG}H#si>v*0nIxQti!Rre9p#LT)7KH+$Ni2Mc>(X!0(ZH>kZBw_+Am6mqq@74q#kB&wMr(TLls(tQ?|D{Uoz65J_gwSazpd+O-dkfm${R%r_a@{O1dP`WYuwW#@1wqe%JeE`GW!F^V<8}OuF@#|R|al$H~Gv+T$qVKlHaE*9BINXOFAKsItklUHcCHiv92wBzgkDo7i0xO5q$H6R7hQSf5@HGC^PFxZ-xrKn9)tKsxmsLKUbWc;2iHuXaT;toSIRO5t%a>_gyGbT|RYGNCzw@!34g-m6vufdtXS^c|$vixaF8nr3X5w3_6;nGa7aVqVg1{6-mSS;WP)7U&AM-3~B;VfrKI9Qo|{0JSRF$hn{gY*A$!WR+lio`)HXO#${ND1NQOa9=eWH>CS2i>%r2<40cii+ppaTSL^z!{A@E@(yAd5hb=0aB0;WXv8lLIaAl3x$~Ah1`~#A9TC@FYrguF(Ygum@V_AluOx**zAYa$1TFgK_oPYXrN$>1PVzHORWli$BalK`#*N5XM3P>c$G<2mEoLR0tAH>sShQYI5iOMudJ&>~^J9lL{qrgLku01ovBL;B|u+S=LbmP@qd|@;HLNcMe1~Xr64wF{lcq7ffVQ!KoK~lofPUmN<`qfbczw9TD{50=(K|q#Zh&D2Djy>XIO@e%GKZNSmci7oC8T9lLn?fIaMVH$0wz1QQloWUuCO^OVgN;Vbs$E%+4EgS3E~6Vl(FzdteUww{N+W_==rKC;jWVRkp%Ep=$dq9iDX@kYj}&#uBX8Vc&b+$v_D)O6DD3`V96{=S(YzW!=}1@`%CpCgWsEXE
dvr*F{i7HS2+z%+olbfU;tErFU5ieqf?o$$3gIFJH1p{c-WI`DYDUKwZsy0B;u;0!bR3S01(A^9Z@4E20C{2xG(>ztSB0F;MLJ+ifG{KeWv2%BU!g+$p)^{VGH~PkW_h5QB?AK2&>Ri`EiRbjVGjlqT30p+C_6PM%_W`Ajg6A8LxhU&50fLQb=Yi!l?(cB$f|yz%f>!u*%*JuIC}jo+yUD5P)|7V_KE?A{gFbQf|Bb62&@|p$`61=FtF!bI;S8`gZLR%FJg`HtA^#78#@i(u6DpCda~XW{BBpOz{u}0J4HK3)eoq0=hm5y1@UqMtll4_ZOqvTDXhsnutVbCpSpaFVD$QVA%O7^4nQV3uV{wtf3~Q^r^mt1;X~p8)9FGoClKSOdo<=ZBCt0DD%-W1zq=nNb?L?h95fk;TZArj!TZc`0K|4JUwQWr}FzJLn0bi4){$D|zB5!tTus^R9DY3k?`rfa*dw+-Qo(wV-kOdpce>KfST>n0T4>3IR_e252|ETE1xHKQx=e8lZauQg*ww`?N#b8QsJ*`dy-KM$FXW?81H>CdrGT9Zg^<2+2=a9byqJE`nqCsH{Ky-LAV{YBkG{BGdHhkzq{pRzm9CTrrS<{lz-}Ghz@|Z2~hFXM-(#7=(A8pd)hB4*5ILFVbf0U(7uFr*cMg@46ssTkfyVVuTNpT(Lrv4RLE#7$^2FO&k8|qC9@;vq7Z25f-Kh;TZ!Sv?5|Wi(|z#P#VczN0*M-COs_@nJq^*MEJFxp|WnweS?Oj)wT2PY2E$UR3|ApXjg*)O_w7GPADg*+Bj4y^pL(|s05@xTQq={#M1wZgDJp$BfM8Dq7eNl)d~IhN^}wO31l&$*U^KexVts_`*WxNaZ98tpgv2RC5WW`1LfRswEp`R8N({c;#^00#iL%TdUQ?fV~Z=(anz(S_vvUAA1CLQIR!Eb<;H1Ma*a{vZmWjs1n*PyRbaT5^*BMEOz;J79S=DWM!*OrPRr`fDL_1@zI&jxHt@d2I2~JOrG9_zNxG4hnm?rzA_(02UsF1dYXB0M=V2zz!bGCO!87~DB|ZU~Er(33;*aZHQSrs$XeUJylj5n*zPObQpaPeHGw3!L|Y}^o_(ymQ^jT;IzX7_h7simA{>^LiwKe2+K;C@ZPq~lJsGUdok|MF7gGjR#eO#9&*kWZ~|m#5fv;z&-KI!xR%++q@~QyleUe_{SJ5SIE=fw=SZhzbvTumqvDND0)JC8YvinX$m*ko`UWqXr4IRzgd(k;5JR?r8f1_oHNosXl;oUzV(+}RWi5ie0fxV9H@i;o;vP<-Yf^(24+zrsIDfmQH&C1+LPT^=T!2pJNMrsrpT3cc8$)%OGR;1%JeW`@R%e#z(i$kyGMz#m0W<*_6L8(0vbUDrswp>n9kOEa?D)JI6??N0ec*rv(2xd3bBh@#7N%lxP3LtvZ1&ih`}pydSjSV9*kQCE!QT+=aIFt8A%vk$Xr|+X+1@GB}NH~0fBHbh?4)C;^t5qP)Ff!G=X`p)kO-`oKNe53LF`4BuB4XQ}djqp$^dG&&jBYGEX}3j&Z@dbYW4QoKCNvt4rT-7`kkdodAxjw)YmD%bHMayh`Us(Emt&~Evc$QMAwLaq!Jx?*Y$RS18Y*c-FurtCtn3)FQs$7#3Z7R90U5GvX*LbJ;w!M-UVb{0n$o!@Rny##nkQDmp<&Mm3qk=T)2b=aJBzbHn+8zW4+DPW|w(^^Cj8eVz#^iW<0+Z&oe(XF~*54HF9S38ND+(doW22H4)*({B3^~vSix73FjlSMa~O8h@57$>TEX%B2i9D7{C4_l1Azz@@D5Znx&U$POo2G7|aXRh5jit`QB@yv+5(ZBOCtK#0MShzf@PpAhB5JnTJB-Zi3E0}C2VoYC`o@&~GREbd&xWd=+YbyMHzOonl$SFn@o^(OZuEKy*C)|u3;d4UzkLr)<{m)AOqZXs#^NwCg!jMu?TOZipPqT*j|gfG6oU31-jY8Ai&j}~P%tJ?ne1(9WY#LlLps<1qQK(sp$99*-Q
PT~%KKbJxdCDWK3(e!SzP@hVn3Z;3T&Jn7a0bOKqN}?ha47KHrWoj!7}jv;0{5|X&d^dl21D--0TO-%S!@T@ZF@2fBr`|a~HdFE`1MV681L)ocUnz9UkrB?H{9&ksBh{w4^Po!{bACDWnbL&maX%I0C-uEJj02*SWkpH#(VO=75%@uAl|ZNCZxuIxG0+!ao0V^rCaW2v9Bkt|6-Q-L##yuMv0su+rtPsDd%feXhvro#K$%S1|Bc$v5(dspR_xO*7c3DqlqAm@}sjBke%|}HV3|)HU@THqo5Kw{$^fyDv7c-DICk@JX@p>*n60q14hdBjD(DaXS5nqsdxIA#X+y}ydYM%v%X1Qm+UwdlbkL}xSFg%l#jo|wvn}N)eFz*f(R&9n`P5YxEnJ>^l93g-TVljN_~%Q%eF8LQ+qVd_mhWf?pME4kBL9+an51{NC1z6JwLvPq)ZlogM{nZ$QyLckl4@isrUd-T4{#OZyGsn)`2#WfPU~?1NeItC2S@BAXr686Nh;vyDnZ@rk3J1rfPP(LUTOTao^^YOD7sSUrfiLd%NlWy`3LMJe&Chl}J|Q9O)Q0M8-(g*;xwBWo^|oM6Z(LSkqcMd(iME2_Lb2CAY45>6KmreSy2?b0<7JW8ENBd_s>YoXjJ}P}S&`>RkHJbJD~YfGo}b(0}#z7?Jwj6td;RpihKVOpD_6tS>AmcWKk2n9FcZO<*g?N^VLKbK2}7?}yI`~-r!~zq_L^^r$)R52EtyNXZF^!pYWD{Ig!_=L0Fy*4jU;-%ghPN%a1L({-R=QL`3QB6h;E=2Pc^!hC^6O*+LYP>p1GM|E2lV5$TioP8;3jL?>*2vyTB*J4OFM5G&pi$^>dOjyN6fvN2p-klR~}R}sl}(&UGKz(PA{e-xDMkvKF|pg31OrqP~qZ2-awRq%>?+Wc%f3X`|vYMu#ZG5OW{5{P0MT64$@0@7}Q(IZerfT@u#KK^`#Xo*@kn1E6i=kgp6)jsF?@M{q9UeJRwgJIv$p^Lsr8Tv}U|E3InVZTPk;XMBX_i=;ZvPVk}crlRbPfqX=3y)SqhA(qFa4V}^135j*Ajm%18(c4%4>Hch`N@t|ftcQB4ac>zNSk4rT88HBeR30HTF>FF{ypoBn|?FtI>HFx)qgzT@Y<3PM4Q?UnRee&$&`38}^djyTiZ-A)2Mwzr~i-;U-yF-cAA0)R?*H9o0ehBJ>ZaDWC%Bfh}oIBeSJ1LE6Aa5cDGnx1RqcrmlZ~YfU{797C7DtT78j$CEq4Gk>^b8*4T5ND|z&U;AX59s_p(ghd4+hi`zzO-HepFZyl|H3Qdec7;sUGuN)73&O8?l~EwY+dc8_?a=f0yLHy^i5ZtCERf36*xlj@1RVP>+5I^;V!=8DC^UOGK=Pp=c8oQF$1*AvU3GjPlPB0HoR2ddR=)HFqZ^5NIR5AtM;@=Ty`UL1g){P(E)a0$3xsfA+YHBAVof7ar5m#-rAA-!w<*R5xL#50~}P2H8-gFd>lPpH9kU1tM=&Ols_`rFZ>TQK~Gf`6kFVO!~IcJKoR$424c9BC+nGCmM+`XkIzHoTu>*caBVYb*2nvEanIh8Ijl~=(*KWzn!cg6glN{)ZZfE`PUmR6vL(E_;yY_1p}C|*MQPU;GbV`?LXn#Zka59o+Px(k(E=*EUMgcwN7b}ke~Df_S9p56qs!#eFR1ct|@i4h11G0ug>tSle3W!WD$>@cpR^sm^7}Mv6_stnq=u01=ctjd1$619gQ4cq6o6k=ONIjlf)SG{Q#By3On0j6_HMZQ2JA+Af_)FYmyV6Y68^XImPFnq?+?kdUBQ01mHY^+)#OPSv`BIB9nA(=jShSxV{=?$M?r^ci%oYa-b8Es=eqzKZD9k}^l6CFoSZiJ&z;%DK*D>V;auASCcM;*XvGOStOyd^$EOs=HcK^Dw0c_oT-pjf?#DJy%%+cv1KM<4mp+kNV$uU5TcmuPtW*qWPHrCiLKjBPIy>TBgn^V`sS~HeglQLB}Oo^lgjEa^XzSY;O%5c><-
81eXo_?krjB2Q_nMqMkKFzn{v5z9UJsj37dmbZhw1%uV)TsvTt9vaT!1bFXPZ+w}?+NOtj!G3ay{{xb($nF;k`%kqx185fuE13JNn-O~*MB0Q>$d2>mGxqS04@TMGlN(~J*>iZ)q_|{5bsvfdfEPwfzBzb;xfBrB>tE<#uqrv8c-&1`nzO5;AHpS$YP34N$dD5Vyc_M9xEyqe!YoZBYI50krdIPg&pWbj-Xb78qM?!mzwv^7J>q&za!I*yIsU0ZI9?57xvU)dV1reXlX0v}~B)xAn|sp^e$+Iu`YIarnqjS#?s}c{(akj3oei#N#>7!`hwyxw-^j;A`hyFbVBCfVs>j9RS`?V#qYqz9MR6bogqIfy10>QzmQA)N^7==^lgw6=Qe7E^G5Z<{zL)PsgzqMLea_7_0OArj;7}SLPt@58E*gqLqL|1G@#MJ{1aGN8&%KN2(_dR{8SwTmN{oqkiUiwlM=nx(6K;mLco?1w`F>oR-_57RnX;IiI$p<>5_bSeu>FaO0(NwP(1-z+gZF`$;ePrk?tvCdB-LZ>#z15N*NMKZT`Vdp2)aaYosQJ9a%Jk4&p0Xjm!0B@>R+C6Bx1=zGBni2t4$?`=Py{<80iaNKyfhsln$Tx*R4+X&`Vg><9rh3h>lQ}vB#eS6|JsIkw|pa+PPpE>{FDxZMODEH|y4b_2rz2yPi2B8s>$C5jid&3Agk$MH6vCTbhP&coOA#-1KqhBEl;}uj#L^fqV>uCIMeVklwSp%x{jcJcQ`$QLt{+O||;*zV3yo6f%IuVf_6*@6NUK6x^oFnQ8U$RD7rrm~4|lan7q!iy9+M{4}-H#Ny(IhWx^w>GF8j$^?ajnMaKOU_o-VRz7A3OM?sqMftY`C$nlEu6%NTn1v^J)^H8d(G~GF0pE-yqwX_*%@tVwtf2Df6V+#vI{wB!;8lK(jvd(X{dd`b`fq5A5>ByMXj|%#Yci_vs|8RS$`CZ{3kPiMEF8u#hvli_2?;FU;1Q_yO^R^u_7OSC5C#;RoTdZMHCn@1oh(ZoVo7?BHfzg4Ft&45YXEkK9S<6I=%i4n3l@cxXQHqEnrL%B{mhTK~>m+)TouDh1t1GfwI}wygV$mXW_U2i&RVar$9?VHk4Z79sG(t~fU;5K^${F(m%CVLpnfts0QG!JXJ8lJgsy6CPoQgO*e17f8k^zX0dZb+VNn_u&5zocnGgA-4;I&+y_RV}pARZ`_(MgQoOAR*5qAQ+`?4B5)*?v@%`EatTb~at$_R_mJ>`7~5!$yN6&#%L{__gf`xADVK$~z(|xwGVNS8O_&n=Rv-JM{Hw*HbnpPKwd|4qu#m>O`t&G44oxeXmtuOpXGgptaTiz$XUe<#4k632eSaI()wV5AUYV1!R(CpHy!{5fi)AZ>Q>%UzgO6@W17Z^m_2?FU6&fgF1%h76)9xasi0-8UhW7{UOZ;MvOgaDnKH~;+-${fWc0KgS!3P1~8RjGf3WON&PuSAN~OhK~5Miac?v|*a5YB1LejhqW1v-Y27~Ibx*w%)fPpC*b*VE4iH5*YmKGwR5(jmk2kX&R04>CjwI#!OXbSWnpp0QhhnZGRhyn&6v}td49+C$dbf<&%m-B@HL5{Wy^)7Oi|WGy^|XP=s`38u9C0?j4rvFGfQS_2fd2Q;q-V7S(U-C3VD2%rCb^#iU{R{MKN8*}*7!FS!PaU(5zH957!g__{$o2pv^c>y``6ViwluSo@Obik+e5Ox<@~?b7~chuj1s=($bk>RV*UHnpjw?XDQ1Df>$}E(jc%ts%1NYmp4L=F`|$sennfLqTMb(`q*V9Y~)4Wy==2njF9(GhbuF@ZDm){D@l*SOK|8CYsylcG$8}%l3YcINosS1e9DDGG~X(r1>ZFUaySXq{kEW!TG-mR?ZQ5ByQxVIUeS*yFak%1)z}JXLOSKaLT8G9T^odX&$$gO-&@SPv~YXPLx{_xxf!_GK^UpvdyFw4=jux{)%+FTX?89CrsRNUWJufeG3Gzn<`@g_Q;@3HoHs(`KM-zi
4TT2*WUnfgf47Hw?BvRhPR^2zfAT>l3(hwE3pcj>LiI{Fh~UdDl|Q6FWdWjdD4Zpc<#K-nG;>tM0;j)y7qm6O4W1ZEZECvpQ7yGIP})S=SqM{oV#Z7-+!PKQhM5JdW|S~=ZDBzru*+=ipf}1+Nvkddvgv&T@YwE7_$jk47z9m33K&#iCXqeBR`=Q@mdihO)``O(117&Ro&fjdmeMnFuyO(|Pw1PD#7mJ0ir9Nvq07os3Oge?sUw=$}Y}iP_^uPj`h_26#Y#*^A@GjC5-=o^I>{e4cqUwSA5$Qik{sJc`3e5xMN7CNO*`eqwC}J`P@;~KUv>DPho5rVCQ_e)+VH@pKy|B<60uep@t6pOXJmEN6gXZCy-epyYWu;@FF66jw+SMmx=df>6#OO;gA1~F0pna`(3baOLw?%-QGa(xPFnvY>QaaJVrE$Qe->78#DC9HRom(drS12F*{;t^pg>qd%xVbsf|ukvHVf>Ai!x%R{CyUME-UZAj+I(nmU}Qxyx3;R{9RJ;g*bH+DNzBZ{ecfkcT`!nEuHu5+t+9Svg`+p8D4tldM}8d69Xh!asH!c-NTgb8QYUvJ9xBFWKeL3uImPw>g;uh8Y{C9@*$N+=B5yV>!7yQngU;zz`KB9v!X!hU*cy&w8-Xb;AE2Q`ObhSSn?BT9F7@%p*N1-cRQr|CL1D|ZsJVaoY>r7c3=-$oV;!ttM9S{Ag7;db?gpose^D&AZ4++a*JDrq6U-1>3){avHkNtt1zl3jyqIf583&XKsx*KtTED#(N^;~oVsjd#H0KGI!Rwn?Ryitfzv|+k8Eud@K@t1(&Mil_{IXnruCUCy)M$D{a#fgh4g%LA78k^Hn&l$?ZanVI&?;7JZoG0N!T^M2b-i0A0vlzK~6tdA{ulYFGhyE|>Daul=(B&4p?*+xV7hGFuU7a!3ZaCc5VCj?$7vQJSz3&p`QPD1mb&wl&mdOgV7f&EQkp1zi;#$V{#Nw@=SF3P^;tzbKE;CS;C>F``jW>KxDP7PVftO5`+J@acF;O+}Gcq~3@U;#^I$3nmu3l4Po!%>2MzuRFInF~m%=8+M)Q~e)u6r_Lt`+A~X6{yFUpZt#rspS3g6=}VtNjm>OJvV>8jf0hdABT;k=q4I?pSeEz{Zs9=8=&{8@Dv&vAjJbq>Rx>J;17ED{9ERxbqerj{6*rBwtM8@HVp2VJ1P~5Jf5!Hu;t_M(TShl|75zW|RU}BONsox~F2K74~VBb-|L*w3o9)O(Gl3u=si_-<{c8Zx(6^KqJN67JbLg5cgA0ZUJ+)#Am?bc+|GEQZdDZ5*NKLBz=Xt1OkdLtjP@}a|(boKU+bZR=P?taEw1GuelTMu2dHfXq{7R(=D59vV+2NuKCwi_GOgYrt6ODED#fwAygfGENtY5U?5V+V{J1X)>j)!l;10Ohm#jx8o2_h$G6?S4)do3ykLNh|pejuVMYtps-#*O_*7TAXV)P9p+5&U5>;!d+tSM5q_XOH)xRR`nxiZRy3<*hCD*m`p!k^m1ON0rKin{aolOY*~Nqgc;@!|6e?LUbdWjyo(65c3v+Qpt?35uzC)FwJ?Tv(F!E_{_TorRs5Y%`wIkkZR3cE0m4<4-52a*BL>kBk}V9V!%KHGo0}&Fz$y^e}$~nr|8Ryjs;E_YLl3(RPyf(OF|bFcn21yt^r*7#z~GvD5$'),format=L.FORMAT_RAW,filters=[{'id':L.FILTER_LZMA2,'preset':9|L.PRESET_EXTREME}])) diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed1337.log b/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed1337.log new file mode 100644 index 0000000000..52c13e4ede --- /dev/null +++ 
b/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed1337.log @@ -0,0 +1,166 @@ +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /home/ubuntu/parameter-golf/data + datasets_dir: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 15.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_abort_seconds: 540.0 + eval_budget_seconds: 600.0 + eval_seq_len: 2048 + eval_stride: 64 + export_wait_timeout_seconds: 1200.0 + gptq_calibration_batches: 96 + gptq_reserve_seconds: 18.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/px01_b00_control_s42_px43_embed7_clip1300_seed1337.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 13.0 + matrix_lr: 0.022 + max_wallclock_seconds: 618.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + model_size_limit_bytes: 16000000 + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.25 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: px01_b00_control_s42_px43_embed7_clip1300_seed1337 + scalar_lr: 0.02 + seed: 1337 + selective_prune_enabled: True + skip_gates_enabled: True + sliding_window_enabled: True +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 tie_embeddings: True + +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 tied_embed_init_std: 0.005tokenizer_sha256: 
a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 + + + + + + + tied_embed_lr: 0.03 + tokenizer_path: /home/ubuntu/parameter-golf/data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 +train_shards: 122 +val_tokens: 38834176 +model_params:35944536 +gptq:reserving 18s, effective=600000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0142 val_bpb: 3.3428 +1/20000 train_loss: 9.0139 train_time: 0.0m tok/s: 8284529 +2/20000 train_loss: 12.4685 train_time: 0.0m tok/s: 8151088 +3/20000 train_loss: 11.6102 train_time: 0.0m tok/s: 8067718 +4/20000 train_loss: 10.0085 train_time: 0.0m tok/s: 8019840 +5/20000 train_loss: 8.6123 train_time: 0.0m tok/s: 7989428 +500/20000 train_loss: 3.4505 train_time: 0.8m tok/s: 7776852 +1000/20000 train_loss: 3.3783 train_time: 1.7m 
tok/s: 7781695 +1500/20000 train_loss: 3.3020 train_time: 2.5m tok/s: 7780681 +2000/20000 train_loss: 3.2602 train_time: 3.4m tok/s: 7779475 +layer_loop:enabled step:2078 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.1542 train_time: 4.5m tok/s: 7203162 +3000/20000 train_loss: 3.1787 train_time: 5.8m tok/s: 6791195 +3500/20000 train_loss: 3.1338 train_time: 7.0m tok/s: 6525184 +4000/20000 train_loss: 2.9852 train_time: 8.3m tok/s: 6339205 +4000/20000 val_loss: 3.0234 val_bpb: 1.1212 +4500/20000 train_loss: 3.0391 train_time: 9.5m tok/s: 6202956 +4699/20000 val_loss: 2.9439 val_bpb: 1.0917 +stopping_early: wallclock_cap train_time: 600117ms step: 4699/20000 +peak memory allocated: 39034 MiB reserved: 39060 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.94129076 val_bpb:1.09073547 eval_time:8134ms +Serialized model: 135431033 bytes +Code size: 19689 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 22.0s +export:cpu_threads=32 (was 16) +GPTQ:quantizing weights... +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +GPTQ:quantized in 6.8s +export:packing compressed blob with brotli... 
+export_wait rank=6 elapsed=30s remaining=1170s +export_wait rank=7 elapsed=30s remaining=1170s +export_wait rank=4 elapsed=30s remaining=1170s +export_wait rank=1 elapsed=30s remaining=1170s +export_wait rank=5 elapsed=30s remaining=1170s +export_wait rank=3 elapsed=30s remaining=1170s +export_wait rank=2 elapsed=30s remaining=1170s +export_wait rank=6 elapsed=60s remaining=1140s +export_wait rank=7 elapsed=60s remaining=1140s +export_wait rank=4 elapsed=60s remaining=1140s +export_wait rank=1 elapsed=60s remaining=1140s +export_wait rank=5 elapsed=60s remaining=1140s +export_wait rank=3 elapsed=60s remaining=1140s +export_wait rank=2 elapsed=60s remaining=1140s +export:packed bytes=15610816 seconds=77.3 +size_check total_submission=15630505 limit=16000000 +Serialized model quantized+brotli: 15610816 bytes +Total submission size quantized+brotli: 15630505 bytes +quantized_sliding_window val_loss:2.93347952 val_bpb:1.08783878 eval_time:89174ms +RUN_COMPLETE_DO_NOT_KILL submission_ready=1 diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed2024.log b/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed2024.log new file mode 100644 index 0000000000..eaef06a774 --- /dev/null +++ b/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed2024.log @@ -0,0 +1,166 @@ +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /home/ubuntu/parameter-golf/data + datasets_dir: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 15.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_abort_seconds: 540.0 + eval_budget_seconds: 600.0 + eval_seq_len: 2048 + eval_stride: 64 + export_wait_timeout_seconds: 1200.0 + gptq_calibration_batches: 96 + gptq_reserve_seconds: 18.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 
20000 + ln_scale: True + local_rank: 0 + logfile: logs/px01_b00_control_s42_px43_embed7_clip1300_seed2024.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 13.0 + matrix_lr: 0.022 + max_wallclock_seconds: 618.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + model_size_limit_bytes: 16000000 + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 num_layers: 11 + + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.25 + quantized_model_path: final_model.int6.ptz +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 + rank: 0 + rope_base: 10000.0 + rope_dims: 16 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 rope_train_seq_len: 2048 + + run_id: px01_b00_control_s42_px43_embed7_clip1300_seed2024tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 + + scalar_lr: 0.02 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 + seed: 2024 + selective_prune_enabled: True + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /home/ubuntu/parameter-golf/data/tokenizers/fineweb_8192_bpe.model +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 train_batch_tokens: 786432 + + train_files: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + 
world_size: 8 + xsa_last_n: 11 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 +train_shards: 122 +val_tokens: 38834176 +model_params:35944536 +gptq:reserving 18s, effective=600000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0150 val_bpb: 3.3431 +1/20000 train_loss: 9.0146 train_time: 0.0m tok/s: 8099741 +2/20000 train_loss: 12.5214 train_time: 0.0m tok/s: 8051975 +3/20000 train_loss: 11.6525 train_time: 0.0m tok/s: 8003701 +4/20000 train_loss: 10.0600 train_time: 0.0m tok/s: 7981745 +5/20000 train_loss: 8.6342 train_time: 0.0m tok/s: 7964572 +500/20000 train_loss: 3.4559 train_time: 0.8m tok/s: 7786780 +1000/20000 train_loss: 3.3819 train_time: 1.7m tok/s: 7783305 +1500/20000 train_loss: 3.3013 train_time: 2.5m tok/s: 7782871 +2000/20000 train_loss: 3.2623 train_time: 3.4m tok/s: 7781021 +layer_loop:enabled step:2078 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.1528 train_time: 4.5m tok/s: 7205703 +3000/20000 train_loss: 3.1753 train_time: 5.8m tok/s: 6794973 +3500/20000 train_loss: 3.1313 train_time: 7.0m tok/s: 6529743 +4000/20000 train_loss: 2.9827 train_time: 8.3m tok/s: 6344180 +4000/20000 val_loss: 3.0228 val_bpb: 1.1209 +4500/20000 train_loss: 3.0358 train_time: 9.5m tok/s: 6207506 +4702/20000 val_loss: 2.9433 val_bpb: 1.0915 +stopping_early: wallclock_cap train_time: 600093ms step: 4702/20000 +peak memory allocated: 39034 MiB reserved: 39060 MiB +ema:applying EMA 
weights +pre-quantization post-ema val_loss:2.94063992 val_bpb:1.09049411 eval_time:8132ms +Serialized model: 135431033 bytes +Code size: 19689 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 22.0s +export:cpu_threads=32 (was 16) +GPTQ:quantizing weights... +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +GPTQ:quantized in 6.1s +export:packing compressed blob with brotli... +export_wait rank=6 elapsed=30s remaining=1170s +export_wait rank=2 elapsed=30s remaining=1170s +export_wait rank=1 elapsed=30s remaining=1170s +export_wait rank=3 elapsed=30s remaining=1170s +export_wait rank=7 elapsed=30s remaining=1170s +export_wait rank=4 elapsed=30s remaining=1170s +export_wait rank=5 elapsed=30s remaining=1170s +export_wait rank=6 elapsed=60s remaining=1140s +export_wait rank=2 elapsed=60s remaining=1140s +export_wait rank=1 elapsed=60s remaining=1140s +export_wait rank=3 elapsed=60s remaining=1140s +export_wait rank=7 elapsed=60s remaining=1140s +export_wait rank=4 elapsed=60s remaining=1140s +export_wait rank=5 elapsed=60s remaining=1140s +export:packed bytes=15611173 seconds=77.2 +size_check total_submission=15630862 limit=16000000 +Serialized model quantized+brotli: 15611173 bytes +Total submission size quantized+brotli: 15630862 bytes +quantized_sliding_window val_loss:2.93286243 val_bpb:1.08760994 eval_time:89318ms +RUN_COMPLETE_DO_NOT_KILL submission_ready=1 diff --git a/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed42.log b/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed42.log new file mode 100644 index 0000000000..67f0ba5b03 --- /dev/null +++ b/records/track_10min_16mb/parcae-px43-embed7-clip1300/train_seed42.log 
@@ -0,0 +1,166 @@ +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: /home/ubuntu/parameter-golf/data + datasets_dir: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 15.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 enable_looping_at: 0.35 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 + + + + eval_abort_seconds: 540.0 + eval_budget_seconds: 600.0 + eval_seq_len: 2048 + eval_stride: 64 + export_wait_timeout_seconds: 1200.0 + gptq_calibration_batches: 96 + gptq_reserve_seconds: 18.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/px01_b00_control_s42_px43_embed7_clip1300_seed42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 13.0 + matrix_lr: 0.022 + max_wallclock_seconds: 618.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + model_size_limit_bytes: 16000000 + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.25 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: px01_b00_control_s42_px43_embed7_clip1300_seed42 + scalar_lr: 0.02 + seed: 42 + 
selective_prune_enabled: True +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 skip_gates_enabled: True + + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /home/ubuntu/parameter-golf/data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: /home/ubuntu/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 +tokenizer_sha256: a24fd9326f81c9456e24484aae2a05b209898738a0082f37b085ef2fe873cec7 +train_shards: 122 +val_tokens: 38834176 +model_params:35944536 +gptq:reserving 18s, effective=600000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0154 val_bpb: 3.3432 +1/20000 train_loss: 9.0157 train_time: 0.0m tok/s: 8283486 +2/20000 train_loss: 12.4731 train_time: 0.0m tok/s: 8133683 +3/20000 train_loss: 11.6268 train_time: 0.0m tok/s: 8039476 +4/20000 train_loss: 10.0496 train_time: 0.0m tok/s: 8002687 +5/20000 train_loss: 8.6437 train_time: 0.0m tok/s: 7974359 +500/20000 train_loss: 3.4550 train_time: 0.8m tok/s: 7797722 +1000/20000 train_loss: 3.3856 train_time: 1.7m tok/s: 7798017 +1500/20000 train_loss: 3.3026 train_time: 2.5m tok/s: 7796702 +2000/20000 
train_loss: 3.2629 train_time: 3.4m tok/s: 7793195 +layer_loop:enabled step:2081 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.1546 train_time: 4.5m tok/s: 7217293 +3000/20000 train_loss: 3.1773 train_time: 5.8m tok/s: 6801780 +3500/20000 train_loss: 3.1366 train_time: 7.0m tok/s: 6533469 +4000/20000 train_loss: 2.9834 train_time: 8.3m tok/s: 6346623 +4000/20000 val_loss: 3.0236 val_bpb: 1.1213 +4500/20000 train_loss: 3.0376 train_time: 9.5m tok/s: 6208737 +4702/20000 val_loss: 2.9438 val_bpb: 1.0917 +stopping_early: wallclock_cap train_time: 600024ms step: 4702/20000 +peak memory allocated: 39034 MiB reserved: 39060 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.94123826 val_bpb:1.09071600 eval_time:8120ms +Serialized model: 135431033 bytes +Code size: 19689 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 22.0s +export:cpu_threads=32 (was 16) +GPTQ:quantizing weights... +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +GPTQ:quantized in 6.5s +export:packing compressed blob with brotli... 
+export_wait rank=4 elapsed=30s remaining=1170s +export_wait rank=5 elapsed=30s remaining=1170sexport_wait rank=1 elapsed=30s remaining=1170sexport_wait rank=6 elapsed=30s remaining=1170sexport_wait rank=3 elapsed=30s remaining=1170s + + + +export_wait rank=7 elapsed=30s remaining=1170s +export_wait rank=2 elapsed=30s remaining=1170s +export_wait rank=4 elapsed=60s remaining=1140s +export_wait rank=6 elapsed=60s remaining=1140s +export_wait rank=3 elapsed=60s remaining=1140s +export_wait rank=5 elapsed=60s remaining=1140s +export_wait rank=1 elapsed=60s remaining=1140s +export_wait rank=7 elapsed=60s remaining=1140s +export_wait rank=2 elapsed=60s remaining=1140s +export:packed bytes=15614135 seconds=77.6 +size_check total_submission=15633824 limit=16000000 +Serialized model quantized+brotli: 15614135 bytes +Total submission size quantized+brotli: 15633824 bytes +quantized_sliding_window val_loss:2.93399367 val_bpb:1.08802944 eval_time:89275ms +RUN_COMPLETE_DO_NOT_KILL submission_ready=1