diff --git a/SUMMARY.md b/SUMMARY.md new file mode 100644 index 0000000000..e4ae57eddc --- /dev/null +++ b/SUMMARY.md @@ -0,0 +1,170 @@ +# Depth Recurrence in Parameter Golf — Research Summary + +Ivan Verbovoy (@iverbovoy) · 20.03.2026 → 20.04.2026 + +## TL;DR + +Single-person submission exploring **depth recurrence** (3 shared transformer blocks × 4 repeats = 12 effective layers) as an alternative to the flat 10-11 layer architectures used by the leaderboard. Best result: **val_bpb 1.1324 (3-seed mean)** on the 10-min track (PR [#1453](https://github.com/openai/parameter-golf/pull/1453)). Additional **4-hour non-record 1.0889** (PR [#895](https://github.com/openai/parameter-golf/pull/895)). OpenAI-acknowledged the approach as novel and published a dedicated non-record PR [#363](https://github.com/openai/parameter-golf/pull/363) inspired by similar exploration. + +## Architecture + +``` +tok_emb (+ optional BigramHash) + value_embeds × 2 + │ + for repeat in {0..3}: + for block in {A, B, C}: # 3 shared blocks + x += loop_embed[layer_idx] # per effective layer + x += Σ value_scales[l,e] * ve_e # per effective layer + x += cross_repeat_scale * block_out_prev_repeat # stateful recurrence + x = block(x, x0, use_xsa=(layer_idx ≥ xsa_start)) + final_norm + tied LM head + softcap +``` + +Key weight-sharing components: +- **loop_embed** `(effective_depth, model_dim)` — positional signal per effective layer +- **cross_repeat_scales** `(num_blocks, num_repeats-1, dim)` — stateful residual from prev repeat +- **resid_mix** — learned per-dim mix between current and block-0 residual +- **XSA** — last 4 effective layers subtract self-value projection +- **Hedge Mixer** — eval-time online mixture of Neural + Unigram + Bigram + Trigram(hash 65K) + Entropy experts + +## Progression + +| Date | PR | Track | Key idea | val_bpb | +|:----:|:--:|:-----:|:---------|--------:| +| 20.03 | [#148](https://github.com/openai/parameter-golf/pull/148) | 10min | Depth Recurrence + Cross-Repeat 
Skip | 1.2196 | +| 25.03 | [#784](https://github.com/openai/parameter-golf/pull/784) | 10min | + XSA(4) + LeakyReLU²(0.5) | 1.2065 | +| 26.03 | [#835](https://github.com/openai/parameter-golf/pull/835) | 10min | + Progressive Depth (2→3→4 repeats) | 1.1980 | +| 26.03 | [#856](https://github.com/openai/parameter-golf/pull/856) | 10min | + Hedge Mixer | 1.1454 | +| 26.03 | **[#895](https://github.com/openai/parameter-golf/pull/895)** | 4h | 4-hour Progressive Depth | **1.0889** | +| 05.04 | [#1384](https://github.com/openai/parameter-golf/pull/1384) | 10min | + tuned schedule + WD + SWA (3-seed) | 1.1441 | +| 07.04 | **[#1453](https://github.com/openai/parameter-golf/pull/1453)** | 10min | + **Int7 attn + Int5 MLP mixed quant** (3-seed) | **1.1324** | + +## Experiments catalog + +### What worked (baseline 1.1324) + +| Technique | Effect | Notes | +|-----------|-------:|:------| +| Depth Recurrence 3×4 | — | Core architecture, enables 23.7M params in 16MB | +| Cross-Repeat Skip | −0.03 | Prev-repeat residual makes recurrence stateful | +| Value embeds (2 tables) | −0.07 | Critical. 
Adds per-layer token lookup | +| XSA last 4 | −0.01 | Self-value bias removal at top layers | +| Progressive Depth (0.30:2, 0.50:3, 1.0:4) | −0.005 | Ramp repeats during training | +| SWA (start 0.6, every 30) | −0.01 | ~44 checkpoints averaged | +| Hedge Mixer (5 experts) | −0.05 | Eval-time mixture, but stochastic (std 0.013) | +| **Int7 attn + Int5 MLP mixed quant** | −0.012 | Frees 2MB for d=880 mlp×3 vs d=832 mlp×2 | +| Muon optimizer + WD=0.04 | — | Standard for challenge | + +### What did NOT improve mean 1.1324 + +Tested on 1–3 seeds and verified neither sliding nor hedge-mean improves: + +| Technique | Result | Why | +|-----------|:------:|:----| +| BigramHash 2048×112 | −0.005 ❌ | Too few buckets, hash collisions dominate | +| BigramHash 3072×112 | +0.005 ❌ | Single-seed −0.003 but 3-seed mean worse: stabilizes hedge but cuts peaks (seed 7 went 1.1193→1.1444) | +| BigramHash 4096×112 | +0.004 ❌ | Past sweet spot, sparse buckets degrade | +| Noisy QAT (default) | +0.011 ❌ | Noise on int5 MLP too large (~amax/15), SWA collects pre-QAT checkpoints | +| LoRA rank-2 per-repeat (attn.proj, mlp.proj) | +0.013 ❌ | Per-repeat signal already saturated by loop_embed + cross_repeat_scales | +| XSA-all (12 layers) | worse | Optimum is last 4, early XSA hurts | +| Inter-repeat RMSNorm | worse | Breaks scaling balance | +| EMA (τ=0.997) | +22ms/step | CPU overhead > benefit at our scale | +| Partial RoPE + VRL + LN Scale (combined) | worse | Too many interacting changes | +| MuonEq-R optimizer | diverged | Incompatible with our Muon setup | +| Auxiliary losses (edge-of-chaos regularization) | neutral | χ stabilized but bpb unchanged at 5 repeats | +| 3×6 d=960 | worse | Fewer steps dominates | +| 6×2 d=640/736 | worse | Too narrow | +| 4L × 3rep | worse | Fewer unique blocks in limited compute | +| TTT (LoRA-based) | −0.002 | Positive but 410s eval; dropped for budget | +| SD-clip k=3.5, k=10 | worse | Percentile-search already near optimum for int8 | + +### GPTQ with 
Hessian error compensation (3-seed validated) + +Implemented column-wise GPTQ with training-data calibration (no access to val). Collects `X^T X` per `nn.Linear` over 5 training batches, then column-by-column quantization with Cholesky(H_inv) error compensation. ~100 lines added to 1496-line submission. + +| Seed | roundtrip Δ | sliding Δ | hedge Δ | +|------|------------:|----------:|--------:| +| 1337 | −0.0034 | −0.0033 | +0.008 | +| 42 | −0.0007 | −0.0008 | −0.0006 | +| 7 | −0.0013 | −0.0013 | +0.023 | +| **3-seed mean** | **−0.0018** | **−0.0018** | **+0.010** | + +**Deterministic improvement** on sliding/roundtrip (both −0.002). **Hedge mean worse by +0.010** — submission #1453's seed 7 hedge was unusually low (1.1193) and we couldn't reproduce that luck in our session. + +Implication: GPTQ makes the model genuinely better (sliding/roundtrip = deterministic metric of model quality), but `val_bpb` is scored on hedge which has ±0.013 seed variance + ±0.008 session variance. The model-level gain gets dominated by hedge stochasticity. + +Not submitting GPTQ as replacement — #1453 remains the best hedge-mean result. GPTQ-enhanced code kept as reference. + +## Key insights + +### 1. Depth recurrence is viable but not SOTA for this challenge + +Our 1.1324 (3-seed) vs SOTA 1.1147 (abaybektursun's flat 11×512 + AR Self-Gen GPTQ + BigramHash 3072×112). Gap ~0.018. Evangelinehelsinki's separate exploration found flat 11L beats 3×3 recurrence by ~0.025 at same trick stack. **Recurrence trades unique parameters for effective depth**, which helps fit 23.7M params in 16MB but underperforms flat architecture per-layer. + +### 2. Hedge Mixer dominates and destabilizes + +Hedge gives ~−0.05 bpb lift over sliding but has huge variance: +- **±0.013 bpb between seeds** (same config) +- **±0.008 bpb between sessions** at identical model weights (sanity-run confirmed roundtrip/sliding match to 0.0002, hedge diverged 0.008) + +Most architectural gains get absorbed by hedge noise. 
Deterministic metrics (sliding, roundtrip) are the reliable signal. + +### 3. Weight-sharing saturates quickly + +On 3×4 recurrence: +- loop_embed + cross_repeat_scales + value_scales already provide per-repeat variance +- LoRA per-repeat on top **hurt** (+0.006 sliding) — the model was already using available capacity +- Inter-repeat RMSNorm also hurt + +Additional per-repeat degrees of freedom have diminishing/negative returns. + +### 4. Progressive Depth schedule matters + +Shifting schedule from (0.40:2, 0.65:3, 1.0:4) to **(0.30:2, 0.50:3, 1.0:4)** gave −0.004 bpb — 55% more full-depth training steps. Combined with longer warmdown (3000 vs 2000) and denser SWA (every 30 vs 50) at higher start frac (0.6 vs 0.4) for ~44 averaged checkpoints. + +### 5. Mixed quantization > uniform + +Separating attn (int7, 63 levels) from MLP (int5, 16 levels): +- Attention quality drop dominates total loss at low precision → keep attn higher +- MLP tolerates aggressive quantization → allows 2MB saving +- 2MB saved → model width up from d=832 mlp×2 → d=880 mlp×3 + +Gain: −0.012 bpb. + +### 6. Calibration data makes GPTQ work + +Original percentile-search GPTQ ("GPTQ-lite" in our code) only optimizes per-row clip point via MSE. Full GPTQ with column-wise Hessian error compensation gave deterministic −0.002..−0.003 on sliding. Training-data calibration worked; AR self-gen calibration would likely stabilize further. + +## Files + +- Main submission: `records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/` (PR #1453 backing) +- 4-hour submission: PR #895 +- Experimental code variants in repo root: `train_gpt_refactored.py`, `train_gpt_exp1.py`, etc. 
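The column-wise GPTQ pass described in the catalog above (collect `X^T X` per `nn.Linear`, quantize column by column, push each column's rounding error onto the not-yet-quantized columns via the Cholesky factor of `H⁻¹`) can be sketched as below. This is a minimal illustration assuming a single per-tensor scale, not the submission's ~100-line implementation with training-batch calibration:

```python
import torch

def gptq_quantize(W: torch.Tensor, H: torch.Tensor,
                  num_levels: int = 63, damp: float = 0.01):
    """Column-wise GPTQ sketch. W is (out, in); H = X^T X from calibration.
    After rounding column i, the residual error is distributed onto columns
    j > i using the upper Cholesky factor of H^{-1} (standard GPTQ update)."""
    W = W.clone().float()
    n = W.shape[1]
    H = H.clone().float()
    # Dampen the diagonal for numerical stability, then factor H^{-1}.
    H += damp * H.diagonal().mean() * torch.eye(n)
    L = torch.linalg.cholesky(H)
    Hinv = torch.cholesky_inverse(L)
    U = torch.linalg.cholesky(Hinv, upper=True)  # U^T U = H^{-1}

    half = num_levels // 2                        # e.g. 63 levels -> ±31
    scale = W.abs().amax() / half
    Q = torch.zeros_like(W)
    for i in range(n):
        col = W[:, i]                             # already error-compensated
        q = (col / scale).round().clamp(-half, half)
        Q[:, i] = q
        err = (col - q * scale) / U[i, i]
        # Push the quantization error onto the remaining columns.
        W[:, i + 1:] -= err.unsqueeze(1) * U[i, i + 1:].unsqueeze(0)
    return Q * scale, scale
```

The per-`nn.Linear` calibration, per-row scales, and zstd packing of the real submission are omitted here; only the error-compensation loop is shown.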
+ +## Reproduction + +Config used for PR #1453 (submitted): +``` +MODEL_DIM=880 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 +NUM_LAYERS=3 NUM_REPEATS=4 +QUANT_LEVELS=63 MLP_QUANT_LEVELS=16 +PROG_DEPTH="0.30:2,0.50:3,1.0:4" +WARMDOWN_ITERS=3000 +SWA_START_FRAC=0.6 SWA_EVERY=30 +MATRIX_LR=0.018 MUON_WD=0.04 +XSA_LAST_N=4 QK_GAIN_INIT=1.5 +USE_HEDGE=1 HEDGE_ETA=0.1 +MAX_WALLCLOCK_SECONDS=600 +``` + +3 seeds tested (1337, 42, 7) on 8× H100 SXM 80GB, PyTorch 2.5.1. + +## Resource footprint + +- RunPod compute grant: ~$950 of $1000 used +- ~25 full training runs + calibration experiments +- 1 person, 32 days + +## Acknowledgments + +Thanks to OpenAI for running this challenge and sponsoring the compute grant. Thanks to **abaybektursun**, **thwu1**, **Raahil Shah**, **Evangelinehelsinki** for publishing detailed submissions that informed several of my experiments (particularly GPTQ calibration, BigramHash sizing, and the noisy-QAT analysis for recurrent architectures). diff --git a/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/README.md b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/README.md new file mode 100644 index 0000000000..79cd726f80 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/README.md @@ -0,0 +1,135 @@ +# Non-record: Depth Recurrence + Int7 Mixed Quant + Parallel Hedge Mixer + +**val_bpb: 1.1324** (3-seed mean, std 0.0131) | **~15.40 MB** | 8×H100 SXM, 600s + +Improves on [PR #1384](https://github.com/openai/parameter-golf/pull/1384) (1.1441 bpb) by **−0.012 bpb** through mixed int7/int5 quantization enabling a wider MLP 3× model, and parallelized hedge mixer eval. 
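The mixed int7/int5 scheme above comes down to round-to-nearest onto grids of different sizes, with the coarser MLP grid producing lower-entropy indices that zstd compresses into the ~2 MB saving. A rough sketch, where `quantize_uniform` and `entropy_bits` are illustrative helpers (the index-entropy is only a proxy for compressed size, not the submission's actual packer):

```python
import torch

def quantize_uniform(w: torch.Tensor, num_levels: int):
    """Round-to-nearest onto a uniform grid of num_levels points
    spanning [-amax, amax]. Returns (dequantized tensor, step size)."""
    amax = w.abs().amax()
    scale = 2 * amax / (num_levels - 1)
    idx = ((w + amax) / scale).round().clamp(0, num_levels - 1)
    return idx * scale - amax, scale

def entropy_bits(w: torch.Tensor, num_levels: int) -> float:
    """Empirical entropy (bits/weight) of the level indices — a crude
    proxy for how well an entropy coder like zstd can compress them."""
    amax = w.abs().amax()
    scale = 2 * amax / (num_levels - 1)
    idx = ((w + amax) / scale).round().clamp(0, num_levels - 1).long()
    counts = torch.bincount(idx.flatten(), minlength=num_levels).float()
    p = counts / counts.sum()
    p = p[p > 0]
    return float(-(p * p.log2()).sum())
```

With Gaussian-like weights, a 16-level grid caps the index entropy at 4 bits/weight while a 63-level grid needs closer to 5, which is the budget headroom reinvested into d=880 with MLP 3×.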
+ +## Results (8×H100 80GB SXM, PyTorch 2.5.1) + +| Seed | Steps | ms/step | Roundtrip | Sliding | **Hedge** | Artifact | Eval time | +|------|-------|---------|-----------|---------|-----------|----------|-----------| +| 1337 | 4,247 | 141.3ms | 1.2168 | 1.1832 | **1.1324** | 15.40 MB | 167s | +| 42 | 4,389 | 136.7ms | 1.2172 | 1.1840 | **1.1454** | 15.28 MB | 164s | +| 7 | 4,391 | 136.7ms | 1.2163 | 1.1828 | **1.1193** | 15.29 MB | 163s | +| **Mean** | **4,342** | **138.2ms** | **1.2168** | **1.1834** | **1.1324** | | **~164s** | + +Additional seeds for variance analysis: seed 2024 → 1.1431, seed 99 → 1.1405. 5-seed mean: **1.1361** (std 0.0095). + +## Changes vs PR #1384 (1.1441 bpb) + +| Change | Effect | Impact | +|--------|--------|--------| +| MLP 2× → 3× (d=832→880) | +38% parameters, wider model | −0.013 sliding bpb | +| Int8 → **Int7 attn** + Int5 MLP | Fits larger model in 16MB budget | enables above | +| Earlier progressive depth (30/50 vs 40/65) | +55% full-depth training steps | −0.004 bpb | +| More SWA (every 30, start 0.6) | 43 checkpoints vs 13 | smoother average | +| Parallel hedge eval (8 GPU) | 580s → 164s eval time | fits 10 min budget | + +## Key Finding: Int7 Attention is the Sweet Spot + +Standard approaches use uniform quantization (all int8 or all int6). Experiments show that **attention and MLP weights have very different sensitivity to quantization**: + +- **Attention weights** directly affect the neural expert in hedge mixer. Int6 (31 levels) causes hedge boost to drop from −0.052 to −0.039 — a significant quality loss. +- **MLP weights** tolerate aggressive quantization. Int5 (16 levels) compresses well with minimal quality impact. +- **Int7 (63 levels)** for attention recovers hedge boost to −0.051, nearly matching int8's −0.052. + +The 2MB saved by using int5 MLP instead of int8 is reinvested into a wider model (d=880 with MLP 3× vs d=832 with MLP 2×). + +| Quant config | Model | Sliding | Hedge | Hedge boost | Size | Fits? 
| +|-------------|-------|---------|-------|-------------|------|-------| +| Int8 attn + Int5 MLP | d=896 | 1.1760 | 1.1349 | −0.041 | 17.4 MB | ✗ | +| **Int7 attn + Int5 MLP** | **d=880** | **1.1832** | **1.1324** | **−0.051** | **15.4 MB** | **✓** | +| Int6 attn + Int5 MLP | d=896 | 1.1870 | 1.1480 | −0.039 | 15.4 MB | ✓ | + +## Architecture: Depth Recurrence + +Instead of 9–11 unique transformer blocks, **3 shared blocks are repeated 4 times** (12 effective layers). This trades unique parameters for effective depth, fitting 23.7M parameters into ~15.4 MB. + +| Parameter | Value | +|-----------|-------| +| Layers × Repeats | 3 × 4 (12 effective) | +| Model dim | 880 | +| Heads / KV heads | 8 / 4 (head_dim=110) | +| MLP multiplier | 3× (hidden=2640) | +| Vocab size | 1024 (SP BPE) | +| Parameters | 23.7M | +| Logit softcap | 30.0 | + +### Recurrence components + +- **Cross-Repeat Skip**: Each block receives a weighted residual from its own output in the previous repeat — turns stateless recurrence into stateful +- **Loop Embedding**: Learned per-layer vector added before each block — depth-wise positional encoding for shared weights +- **Value Embeddings**: 2 extra embedding tables mixed into the residual stream at each effective layer with learned scales +- **XSA (Exclusive Self-Attention)**: On last 4 effective layers — prevents attention collapse in deep recurrent models +- **LeakyReLU(0.5)²**: Better gradient flow than ReLU² for deep/recurrent models + +## Progressive Depth Training + +Training uses increasing recurrence depth, recompiling at phase boundaries: + +| Phase | Wallclock | Repeats | Effective layers | Step speed | +|-------|-----------|---------|-----------------|------------| +| 0–30% | 0–180s | 2 | 6 | ~90ms | +| 30–50% | 180–300s | 3 | 9 | ~105ms | +| 50–100% | 300–600s | 4 | 12 | ~130ms | + +Schedule tuned for the MLP 3× config: earlier transitions (30/50% vs 40/65% in PR #1384) give +55% more steps at full depth. 
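The `PROG_DEPTH` string from the config (`"0.30:2,0.50:3,1.0:4"`) maps wallclock fraction to repeat count. A hypothetical reading of that format (helper names are assumptions, not taken from `train_gpt.py`):

```python
def parse_prog_depth(schedule: str) -> list[tuple[float, int]]:
    """Parse "0.30:2,0.50:3,1.0:4" into [(0.30, 2), (0.50, 3), (1.0, 4)]."""
    phases = []
    for part in schedule.split(","):
        frac, repeats = part.split(":")
        phases.append((float(frac), int(repeats)))
    return phases

def repeats_at(progress: float, phases: list[tuple[float, int]]) -> int:
    """Recurrence repeats to use at a given wallclock fraction in [0, 1]."""
    for frac, repeats in phases:
        if progress < frac:
            return repeats
    return phases[-1][1]
```

Each phase boundary triggers a recompile of the model at the new effective depth, as in the phase table above.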
Warmdown 3000 iterations, SWA every 30 steps from LR scale < 0.6 (~43 checkpoints). + +## Eval: Parallel Hedge Mixer + +5-expert online ensemble with **8-GPU parallelized forward pass**: + +| Expert | Description | +|--------|-------------| +| Neural | Model's own logits (log-softmax) | +| Unigram | Global token frequency with Laplace smoothing | +| Bigram | Conditional P(token \| prev_token) | +| Trigram | Hashed trigram context (65K buckets) | +| Entropy | Model's entropy as calibration signal | + +**Parallelization**: Each batch of windows is split across 8 GPUs for the forward pass, logits gathered via `all_gather` to rank 0 for sequential mixer scoring. This reduces hedge eval from 580s (single GPU) to **164s**, fitting within the 10-minute eval budget. + +Hedge provides **−0.051 bpb improvement** over sliding window (1.1834 → 1.1324 mean). + +## Training Details + +| Parameter | Value | +|-----------|-------| +| Optimizer | Muon (matrices) + Adam (scalars, embeddings) | +| Matrix / Scalar LR | 0.018 / 0.018 | +| Tied embed LR | 0.021 | +| Muon WD | 0.04 | +| Muon momentum | 0.95 (warmup 0.85→0.95 over 500 steps) | +| Grad clip | 0.3 | +| Batch tokens | 524,288 | +| Quantization | Int7 attn (63 levels) + Int5 MLP (16 levels) + zstd-22 | + +## Evolution + +| PR | Date | Score | What changed | +|----|------|-------|-------------| +| [#148](https://github.com/openai/parameter-golf/pull/148) | Mar 20 | 1.2196 (sliding) | Depth recurrence (3×4), cross-repeat skip, value embeddings | +| [#784](https://github.com/openai/parameter-golf/pull/784) | Mar 25 | 1.2065 (sliding) | + XSA(4), LeakyReLU², GPTQ-lite, zstd-22 | +| [#835](https://github.com/openai/parameter-golf/pull/835) | Mar 26 | 1.1980 (sliding) | + Progressive depth training (+30% steps) | +| [#1384](https://github.com/openai/parameter-golf/pull/1384) | Apr 5 | 1.1441 (hedge) | + Hedge Mixer (5-expert eval-time ensemble) | +| **This PR** | Apr 8 | **1.1324** (hedge) | + Int7 mixed quant, MLP 3×, d=880, 
parallel hedge | + +## Lineage + +- Depth recurrence architecture — original to this submission line +- XSA from [PR #198](https://github.com/openai/parameter-golf/pull/198) (unnir) +- LeakyReLU² from [PR #493](https://github.com/openai/parameter-golf/pull/493) (parinzee) +- Mixed int5/int6 quantization concept from [PR #549](https://github.com/openai/parameter-golf/pull/549) (thwu1), extended here to int7 +- SWA, Muon WD from modded-nanogpt community + +## Reproducing + +```bash +SEED=1337 QUANT_LEVELS=63 MLP_QUANT_LEVELS=16 \ +MODEL_DIM=880 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \ +NUM_LAYERS=3 NUM_REPEATS=4 XSA_LAST_N=4 NUM_VALUE_EMBEDS=2 \ +PROG_DEPTH="0.30:2,0.50:3,1.0:4" \ +WARMDOWN_ITERS=3000 SWA_START_FRAC=0.6 SWA_EVERY=30 \ +VOCAB_SIZE=1024 TRAIN_SEQ_LEN=1024 TRAIN_BATCH_TOKENS=524288 \ +torchrun --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/submission.json b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/submission.json new file mode 100644 index 0000000000..42c36b509d --- /dev/null +++ b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/submission.json @@ -0,0 +1,16 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Depth Recurrence + Int7 Mixed Quantization + Parallel Hedge Mixer", + "blurb": "3 shared blocks x 4 repeats (12 effective layers) with MLP 3x (d=880), progressive depth (2->3->4 repeats), int7 attention (63 levels) + int5 MLP (16 levels) mixed quantization, 8-GPU parallel Hedge Mixer eval. Key finding: int7 is the sweet spot for attention quantization — recovers 98% of int8 hedge quality while saving 2MB for a wider model. 
5 seeds tested, 3-seed mean reported.", + "date": "2026-04-08T00:00:00Z", + "val_loss": 1.91197327, + "val_bpb": 1.13237601, + "roundtrip_val_bpb": 1.21676461, + "sliding_val_bpb": 1.18335612, + "seeds": [1337, 42, 7], + "mean_steps": 4342, + "wallclock_seconds": 600, + "eval_seconds": 164, + "bytes_model_int8_zstd22": 15403955 +} diff --git a/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_gpt.py b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_gpt.py new file mode 100644 index 0000000000..e9947b2728 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_gpt.py @@ -0,0 +1,1465 @@ +"""Progressive Depth + Hedge Mixer submission. 1500 line limit.""" +from __future__ import annotations +import copy, glob, io, math, os, random, subprocess, sys, time, uuid +import zstandard as zstd +from pathlib import Path +import numpy as np +import sentencepiece as spm +import torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +class HedgeMixer: + """Online mixture of 5 experts via Hedge algorithm for eval-time improvement. 
+ Experts: Neural, Unigram, Bigram, Trigram (hashed), Entropy.""" + def __init__(self, vocab_size: int = 1024, device: str = "cuda", eta: float = 0.1): + self.V = vocab_size + self.device = device + self.eta = eta + self.log_weights = torch.zeros(5, device=device) + self.log_weights[0] = 2.0 # bias toward neural + self.uni_counts = torch.zeros(vocab_size, device=device) + self.bi_counts = torch.zeros(vocab_size, vocab_size, device=device) + self.total_tokens = 0 + self.TRI_HASH = 65536 + self.tri_counts = torch.zeros(self.TRI_HASH, vocab_size, device=device) + self.tri_row_totals = torch.zeros(self.TRI_HASH, device=device) + + def update(self, tokens: Tensor) -> None: + t = tokens.to(self.device).long() + n = t.numel() + if n == 0: + return + self.total_tokens += n + ones = torch.ones(n, device=self.device) + self.uni_counts.scatter_add_(0, t, ones) + if n >= 2: + bi_idx = t[:-1] * self.V + t[1:] + self.bi_counts.reshape(-1).scatter_add_(0, bi_idx, torch.ones(n - 1, device=self.device)) + if n >= 3: + tri_ctx = ((t[:-2] * 36313) ^ (t[1:-1] * 27191)) % self.TRI_HASH + tri_idx = tri_ctx * self.V + t[2:] + ones_tri = torch.ones(n - 2, device=self.device) + self.tri_counts.reshape(-1).scatter_add_(0, tri_idx, ones_tri) + self.tri_row_totals.scatter_add_(0, tri_ctx, ones_tri) + + def mix_and_score(self, neural_logits: Tensor, x_batch: Tensor, y_batch: Tensor, wlens: list[int]) -> Tensor: + bsz, slen, V = neural_logits.shape + uniform_nll = math.log(self.V) + has_data = self.total_tokens > 0 + neural_lp = F.log_softmax(neural_logits, dim=-1) + neural_nll = -neural_lp.gather(2, y_batch.unsqueeze(2)).squeeze(2) + if not has_data or self.total_tokens < 10000: + return neural_nll + uni_probs = (self.uni_counts + 0.1) / (self.total_tokens + 0.1 * self.V) + uni_nll = -uni_probs.log()[y_batch] + bi_total = self.bi_counts.sum(dim=1, keepdim=True) + bi_probs = (self.bi_counts + 0.1) / (bi_total + 0.1 * self.V) + bi_nll = -bi_probs.log()[x_batch.reshape(-1), 
y_batch.reshape(-1)].reshape(bsz, slen) + if slen >= 2: + prev2 = torch.zeros_like(x_batch) + prev2[:, 1:] = x_batch[:, :-1] + ctx_hash = ((prev2 * 36313) ^ (x_batch * 27191)) % self.TRI_HASH + tri_count = self.tri_counts[ctx_hash.reshape(-1).long(), y_batch.reshape(-1).long()] + tri_total = self.tri_row_totals[ctx_hash.reshape(-1).long()].clamp(min=1) + tri_nll = -(((tri_count + 0.01) / (tri_total + 0.01 * self.V)).log()).reshape(bsz, slen) + else: + tri_nll = torch.full((bsz, slen), uniform_nll, device=self.device) + entropy_nll = -(neural_lp.exp() * neural_lp).sum(-1) + expert_nll = torch.stack([neural_nll, uni_nll, bi_nll, tri_nll, entropy_nll], dim=-1) + log_w = self.log_weights - self.log_weights.logsumexp(0) + mixed_nll = -(-expert_nll + log_w.unsqueeze(0).unsqueeze(0)).logsumexp(dim=-1) + # Update weights + wlens_t = torch.tensor(wlens, device=self.device, dtype=torch.long) + mask = torch.arange(slen, device=self.device).unsqueeze(0) < wlens_t.unsqueeze(1) + masked_nll = expert_nll * mask.unsqueeze(-1).float() + expert_mean_loss = masked_nll.sum(dim=(0, 1)) / mask.sum().clamp(min=1) + self.log_weights -= self.eta * expert_mean_loss + return mixed_nll + +# HYPERPARAMETERS + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + 
train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") + + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + use_hedge = bool(int(os.environ.get("USE_HEDGE", "1"))) + hedge_eta = float(os.environ.get("HEDGE_ETA", 0.1)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 0)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 112)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.021)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.018)) + scalar_lr = 
float(os.environ.get("SCALAR_LR", 0.018)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# MUON OPTIMIZER + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = 
p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + +# TOKENIZER-AGNOSTIC EVALUATION + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + 
if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = 
val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + use_hedge: bool = False, + hedge_eta: float = 0.1, +) -> tuple[float, float]: + """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. + Only the last stride tokens per window are scored (first window scores all). 
+ Optional Hedge Mixer: online n-gram ensemble over scored tokens.""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + + # Without Hedge: distribute windows across ranks + # With Hedge: parallel forward across ranks, mixer on rank 0 + use_dist = dist.is_available() and dist.is_initialized() and world_size > 1 + + if not use_hedge: + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + mixer = HedgeMixer(vocab_size=args.vocab_size, device=device, eta=hedge_eta) if use_hedge else None + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + if use_hedge: + # Parallel forward across all GPUs, mixer on rank 0 + hedge_batch = batch_seqs + for bi in range(0, len(window_starts), hedge_batch): + batch_ws = window_starts[bi : bi + hedge_batch] + bsz = len(batch_ws) + + # All ranks build full batch (cheap indexing) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + # Pad to multiple of world_size + per_rank = (bsz + world_size - 1) // world_size + padded = per_rank * world_size + if padded > bsz: + x_batch = F.pad(x_batch, (0, 0, 0, padded - bsz)) + y_batch = F.pad(y_batch, (0, 0, 0, 
padded - bsz)) + + # Each rank forwards its slice + my_x = x_batch[rank * per_rank : (rank + 1) * per_rank] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + my_logits = base_model.forward_logits(my_x) + + # Gather logits to all ranks + if use_dist: + gathered = [torch.zeros_like(my_logits) for _ in range(world_size)] + dist.all_gather(gathered, my_logits.contiguous()) + logits = torch.cat(gathered, dim=0)[:bsz] + else: + logits = my_logits[:bsz] + + # Rank 0: mixer scoring + n-gram update + if rank == 0: + nll = mixer.mix_and_score(logits.float(), x_batch[:bsz], y_batch[:bsz], wlens) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mixer.update(y_batch[i, s:wlen]) + + if use_dist: + dist.barrier() + else: + # Non-hedge: each rank processes its windows independently + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits 
= base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if not use_hedge and use_dist: + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# QUANTIZATION + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. 
+QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") + + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS + t32 = t.float() + if t32.ndim == 2: + abs_t = t32.abs() + best_q, best_scale, best_mse = None, None, None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = torch.quantile(abs_t, pct, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs[:, None], clip_abs[:, None]) / s[:, None]), -ql, ql) + mse = (t32 - q * s[:, None]).square().sum(dim=1) + if best_mse is None: + best_mse, best_q, best_scale = mse, q, s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, 
ql).to(torch.int8).contiguous() + return q, scale + + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} + if s.ndim > 0: + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# DATA LOADING + +def load_data_shard(file: Path) -> Tensor: + # Assumed standard shard layout: 256 little-endian int32 header words + # (header[2] = token count) followed by uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2") + if tokens.size != num_tokens: + raise ValueError(f"Shard {file} is truncated: expected {num_tokens} tokens, got {tokens.size}") + return torch.from_numpy(tokens.astype(np.int32)) + +class TokenStream: + # Endless token stream over the sorted shard files matching `pattern`, wrapping to the first file. + def __init__(self, pattern: str): + self.files = sorted(Path(pattern).parent.glob(Path(pattern).name)) + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# TRANSFORMER MODULES + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. 
+ with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = 
CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + 
self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class BigramHashEmbedding(nn.Module): + """Hashed bigram embedding — adds token-pair context to input embeddings.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + t = token_ids.to(torch.int32) + mod = self.bigram_vocab_size - 1 + h = torch.empty_like(t) + h[..., 0] = mod + h[..., 1:] = (36313 * t[..., 1:] 
^ 27191 * t[..., :-1]) % mod + out = self.embed(h.long()) + if self.proj is not None: + out = self.proj(out) + return out * self.scale.to(dtype=out.dtype) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + bigram_vocab_size: int = 0, + bigram_dim: int = 112, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n + effective_depth = num_layers * num_repeats + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + 
self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_hidden(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(cur_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + if self.num_value_embeds > 0 and layer_idx < self.value_scales.size(0): + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1) + scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0, use_xsa=(layer_idx >= xsa_start)) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 
1 + + return x + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.forward_hidden(input_ids) + x = self.final_norm(x) + proj_w = self.tok_emb.weight if self.tie_embeddings else self.lm_head.weight + logits_proj = F.linear(x, proj_w) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + +# TRAINING + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # DISTRIBUTED + CUDA SETUP + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def 
log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION METRIC SETUP + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + def recompile(): + cm = torch.compile(base_model, dynamic=False, fullgraph=True) + return DDP(cm, device_ids=[local_rank], broadcast_buffers=False) if distributed else cm + compiled_model = recompile() + model: nn.Module = compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [p for name, p in block_named_params if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = [p for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + adam_kw = dict(betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizer_tok = torch.optim.Adam([{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], **adam_kw) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], 
**adam_kw) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], **adam_kw) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu" + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} gpu:{gpu_name} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}") + log0(f"seed:{args.seed}") + # DATA LOADER & MODEL WARMUP + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if 
remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." 
and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = recompile() + model = compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions + if max_wallclock_ms is not None and prog_phases: + if distributed: + elapsed_tensor = torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + frac = elapsed_tensor.item() / max_wallclock_ms + else: + 
frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = recompile() + model = compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (only at full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in 
base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = recompile() + model = compiled_model + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always 
produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Hedge Mixer eval (parallel forward across all GPUs, mixer on rank 0) + if args.use_hedge and args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_hm = time.perf_counter() + hm_val_loss, hm_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + use_hedge=True, hedge_eta=args.hedge_eta, + ) + torch.cuda.synchronize() + log0( + f"final_hedge_mixer val_loss:{hm_val_loss:.4f} val_bpb:{hm_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_hm):.0f}ms" + ) + log0(f"final_hedge_mixer_exact val_loss:{hm_val_loss:.8f} val_bpb:{hm_val_bpb:.8f}") + + # Destroy DDP after all evals + if distributed: + dist.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed1337.log b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed1337.log new file mode 100644 index 
0000000000..6766a6ffe9 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed1337.log @@ -0,0 +1,295 @@ +W0407 17:45:44.409000 26546 torch/distributed/run.py:793] +W0407 17:45:44.409000 26546 torch/distributed/run.py:793] ***************************************** +W0407 17:45:44.409000 26546 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0407 17:45:44.409000 26546 torch/distributed/run.py:793] ***************************************** +logs/6a6652b3-fd76-4cf4-b620-e2c2fe011f94.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:23662344 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.021 head_lr:0.0 matrix_lr:0.018 scalar_lr:0.018 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.3, 2), (0.5, 3), (1.0, 4)] starting_repeats=2 +step:0/20000 val_loss:6.9331 val_bpb:4.1062 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9943 train_time:13115ms step_avg:13115.41ms +step:2/20000 train_loss:10.1329 
train_time:13135ms step_avg:6567.35ms +step:3/20000 train_loss:9.9791 train_time:13221ms step_avg:4407.06ms +step:4/20000 train_loss:9.6349 train_time:13309ms step_avg:3327.31ms +step:5/20000 train_loss:9.0052 train_time:13399ms step_avg:2679.85ms +step:6/20000 train_loss:8.4087 train_time:13489ms step_avg:2248.15ms +step:7/20000 train_loss:7.4150 train_time:13580ms step_avg:1939.98ms +step:8/20000 train_loss:6.7232 train_time:13669ms step_avg:1708.68ms +step:9/20000 train_loss:6.1753 train_time:13760ms step_avg:1528.92ms +step:10/20000 train_loss:5.8159 train_time:13849ms step_avg:1384.95ms +step:50/20000 val_loss:4.1892 val_bpb:2.4811 train_time:17468ms step_avg:349.36ms +step:100/20000 val_loss:3.3132 val_bpb:1.9623 train_time:21923ms step_avg:219.23ms +step:150/20000 val_loss:2.8768 val_bpb:1.7038 train_time:26392ms step_avg:175.95ms +step:200/20000 train_loss:2.7423 train_time:30877ms step_avg:154.38ms +step:200/20000 val_loss:2.7293 val_bpb:1.6164 train_time:30928ms step_avg:154.64ms +step:250/20000 val_loss:2.6200 val_bpb:1.5517 train_time:35410ms step_avg:141.64ms +step:300/20000 val_loss:2.5608 val_bpb:1.5167 train_time:39901ms step_avg:133.00ms +step:350/20000 val_loss:2.5257 val_bpb:1.4959 train_time:44396ms step_avg:126.85ms +step:400/20000 train_loss:2.2840 train_time:48910ms step_avg:122.27ms +step:400/20000 val_loss:2.4928 val_bpb:1.4764 train_time:48967ms step_avg:122.42ms +step:450/20000 val_loss:2.4617 val_bpb:1.4580 train_time:53466ms step_avg:118.81ms +step:500/20000 val_loss:2.4430 val_bpb:1.4469 train_time:57968ms step_avg:115.94ms +step:550/20000 val_loss:2.4277 val_bpb:1.4378 train_time:62468ms step_avg:113.58ms +step:600/20000 train_loss:2.5024 train_time:66982ms step_avg:111.64ms +step:600/20000 val_loss:2.4049 val_bpb:1.4243 train_time:67039ms step_avg:111.73ms +step:650/20000 val_loss:2.3913 val_bpb:1.4163 train_time:71540ms step_avg:110.06ms +step:700/20000 val_loss:2.3708 val_bpb:1.4041 train_time:76048ms step_avg:108.64ms 
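The `prog_depth: schedule=[(0.3, 2), (0.5, 3), (1.0, 4)]` line above comes from the `"frac:repeats"` schedule spec parsed in the training script. A self-contained sketch of that parsing and phase lookup (logic mirrored from the script; `parse_schedule` and `repeats_for` are illustrative names, not the script's):

```python
def parse_schedule(spec: str) -> list[tuple[float, int]]:
    """Parse "frac:repeats,..." into sorted (fraction, repeats) phases."""
    phases = []
    for entry in spec.split(","):
        frac_s, rep_s = entry.strip().split(":")
        phases.append((float(frac_s), int(rep_s)))
    phases.sort()
    return phases

def repeats_for(frac: float, phases: list[tuple[float, int]]) -> int:
    """Pick the repeat count for the current elapsed-wallclock fraction."""
    for phase_frac, phase_rep in phases:
        if frac < phase_frac:
            return phase_rep
    return phases[-1][1]  # past the last boundary: stay at final depth

phases = parse_schedule("0.3:2,0.5:3,1.0:4")
assert repeats_for(0.1, phases) == 2
assert repeats_for(0.31, phases) == 3
assert repeats_for(0.9, phases) == 4
```

This matches the log: the run starts at 2 repeats and switches to 3 at frac 0.30 and to 4 at frac 0.50.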
+step:750/20000 val_loss:2.3562 val_bpb:1.3955 train_time:80589ms step_avg:107.45ms +step:800/20000 train_loss:2.2624 train_time:85101ms step_avg:106.38ms +step:800/20000 val_loss:2.3481 val_bpb:1.3907 train_time:85159ms step_avg:106.45ms +step:850/20000 val_loss:2.3338 val_bpb:1.3822 train_time:89666ms step_avg:105.49ms +step:900/20000 val_loss:2.3242 val_bpb:1.3765 train_time:94167ms step_avg:104.63ms +step:950/20000 val_loss:2.3160 val_bpb:1.3717 train_time:98666ms step_avg:103.86ms +step:1000/20000 train_loss:2.3457 train_time:103178ms step_avg:103.18ms +step:1000/20000 val_loss:2.3067 val_bpb:1.3661 train_time:103236ms step_avg:103.24ms +step:1050/20000 val_loss:2.2990 val_bpb:1.3616 train_time:107739ms step_avg:102.61ms +step:1100/20000 val_loss:2.2933 val_bpb:1.3582 train_time:112242ms step_avg:102.04ms +step:1150/20000 val_loss:2.2929 val_bpb:1.3580 train_time:116803ms step_avg:101.57ms +step:1200/20000 train_loss:2.3646 train_time:121250ms step_avg:101.04ms +step:1200/20000 val_loss:2.2786 val_bpb:1.3495 train_time:121309ms step_avg:101.09ms +step:1250/20000 val_loss:2.2763 val_bpb:1.3481 train_time:125807ms step_avg:100.65ms +step:1300/20000 val_loss:2.2680 val_bpb:1.3432 train_time:130299ms step_avg:100.23ms +step:1350/20000 val_loss:2.2661 val_bpb:1.3421 train_time:134865ms step_avg:99.90ms +step:1400/20000 train_loss:2.4132 train_time:139304ms step_avg:99.50ms +step:1400/20000 val_loss:2.2606 val_bpb:1.3388 train_time:139362ms step_avg:99.54ms +step:1450/20000 val_loss:2.2569 val_bpb:1.3367 train_time:143859ms step_avg:99.21ms +step:1500/20000 val_loss:2.2532 val_bpb:1.3345 train_time:148351ms step_avg:98.90ms +step:1550/20000 val_loss:2.2564 val_bpb:1.3364 train_time:152914ms step_avg:98.65ms +step:1600/20000 train_loss:2.0841 train_time:157349ms step_avg:98.34ms +step:1600/20000 val_loss:2.2499 val_bpb:1.3325 train_time:157407ms step_avg:98.38ms +step:1650/20000 val_loss:2.2437 val_bpb:1.3288 train_time:161903ms step_avg:98.12ms +step:1700/20000 
val_loss:2.2387 val_bpb:1.3259 train_time:166395ms step_avg:97.88ms +step:1750/20000 val_loss:2.2371 val_bpb:1.3249 train_time:170950ms step_avg:97.69ms +step:1800/20000 train_loss:2.1831 train_time:175380ms step_avg:97.43ms +step:1800/20000 val_loss:2.2362 val_bpb:1.3244 train_time:175438ms step_avg:97.47ms +step:1850/20000 val_loss:2.2298 val_bpb:1.3206 train_time:179926ms step_avg:97.26ms +prog_depth: switched to 3 repeats at step:1851 frac:0.30 +step:1900/20000 val_loss:2.2491 val_bpb:1.3321 train_time:204494ms step_avg:107.63ms +step:1950/20000 val_loss:2.2267 val_bpb:1.3188 train_time:211170ms step_avg:108.29ms +step:2000/20000 train_loss:2.2371 train_time:217713ms step_avg:108.86ms +step:2000/20000 val_loss:2.2152 val_bpb:1.3120 train_time:217806ms step_avg:108.90ms +step:2050/20000 val_loss:2.2118 val_bpb:1.3100 train_time:224442ms step_avg:109.48ms +step:2100/20000 val_loss:2.2142 val_bpb:1.3114 train_time:231154ms step_avg:110.07ms +step:2150/20000 val_loss:2.2049 val_bpb:1.3059 train_time:237816ms step_avg:110.61ms +step:2200/20000 train_loss:2.0426 train_time:244382ms step_avg:111.08ms +step:2200/20000 val_loss:2.2016 val_bpb:1.3039 train_time:244473ms step_avg:111.12ms +step:2250/20000 val_loss:2.2006 val_bpb:1.3033 train_time:251143ms step_avg:111.62ms +step:2300/20000 val_loss:2.1949 val_bpb:1.3000 train_time:257858ms step_avg:112.11ms +step:2350/20000 val_loss:2.1934 val_bpb:1.2990 train_time:264514ms step_avg:112.56ms +step:2400/20000 train_loss:2.1688 train_time:271085ms step_avg:112.95ms +step:2400/20000 val_loss:2.1873 val_bpb:1.2954 train_time:271175ms step_avg:112.99ms +step:2450/20000 val_loss:2.1832 val_bpb:1.2930 train_time:277832ms step_avg:113.40ms +step:2500/20000 val_loss:2.1776 val_bpb:1.2897 train_time:284551ms step_avg:113.82ms +step:2550/20000 val_loss:2.1742 val_bpb:1.2877 train_time:291202ms step_avg:114.20ms +step:2600/20000 train_loss:2.3733 train_time:297764ms step_avg:114.52ms +step:2600/20000 val_loss:2.1771 val_bpb:1.2894 
train_time:297854ms step_avg:114.56ms +prog_depth: switched to 4 repeats at step:2617 frac:0.50 +step:2650/20000 val_loss:2.1750 val_bpb:1.2882 train_time:317247ms step_avg:119.72ms +step:2700/20000 val_loss:2.1576 val_bpb:1.2778 train_time:326061ms step_avg:120.76ms +step:2750/20000 val_loss:2.1505 val_bpb:1.2737 train_time:334837ms step_avg:121.76ms +step:2800/20000 train_loss:2.1765 train_time:343536ms step_avg:122.69ms +step:2800/20000 val_loss:2.1426 val_bpb:1.2690 train_time:343626ms step_avg:122.72ms +step:2850/20000 val_loss:2.1363 val_bpb:1.2653 train_time:352437ms step_avg:123.66ms +step:2900/20000 val_loss:2.1320 val_bpb:1.2627 train_time:361316ms step_avg:124.59ms +step:2950/20000 val_loss:2.1271 val_bpb:1.2598 train_time:370141ms step_avg:125.47ms +swa:start step:3000 +step:3000/20000 train_loss:2.1503 train_time:378869ms step_avg:126.29ms +step:3000/20000 val_loss:2.1216 val_bpb:1.2565 train_time:379013ms step_avg:126.34ms +step:3050/20000 val_loss:2.1204 val_bpb:1.2558 train_time:387855ms step_avg:127.17ms +step:3100/20000 val_loss:2.1130 val_bpb:1.2515 train_time:396766ms step_avg:127.99ms +step:3150/20000 val_loss:2.1111 val_bpb:1.2503 train_time:405626ms step_avg:128.77ms +step:3200/20000 train_loss:2.1037 train_time:414367ms step_avg:129.49ms +step:3200/20000 val_loss:2.1049 val_bpb:1.2466 train_time:414456ms step_avg:129.52ms +step:3250/20000 val_loss:2.1014 val_bpb:1.2446 train_time:423388ms step_avg:130.27ms +step:3300/20000 val_loss:2.0977 val_bpb:1.2424 train_time:432248ms step_avg:130.98ms +step:3350/20000 val_loss:2.0949 val_bpb:1.2407 train_time:441082ms step_avg:131.67ms +step:3400/20000 train_loss:2.0645 train_time:449853ms step_avg:132.31ms +step:3400/20000 val_loss:2.0901 val_bpb:1.2379 train_time:449940ms step_avg:132.34ms +step:3450/20000 val_loss:2.0866 val_bpb:1.2358 train_time:458859ms step_avg:133.00ms +step:3500/20000 val_loss:2.0810 val_bpb:1.2325 train_time:467688ms step_avg:133.63ms +step:3550/20000 val_loss:2.0772 
val_bpb:1.2302 train_time:476565ms step_avg:134.24ms +step:3600/20000 train_loss:2.0100 train_time:485313ms step_avg:134.81ms +step:3600/20000 val_loss:2.0728 val_bpb:1.2276 train_time:485427ms step_avg:134.84ms +step:3650/20000 val_loss:2.0693 val_bpb:1.2255 train_time:494314ms step_avg:135.43ms +step:3700/20000 val_loss:2.0659 val_bpb:1.2236 train_time:503167ms step_avg:135.99ms +step:3750/20000 val_loss:2.0624 val_bpb:1.2215 train_time:512023ms step_avg:136.54ms +step:3800/20000 train_loss:2.1044 train_time:520751ms step_avg:137.04ms +step:3800/20000 val_loss:2.0582 val_bpb:1.2190 train_time:520839ms step_avg:137.06ms +step:3850/20000 val_loss:2.0549 val_bpb:1.2170 train_time:529747ms step_avg:137.60ms +step:3900/20000 val_loss:2.0515 val_bpb:1.2150 train_time:538610ms step_avg:138.11ms +step:3950/20000 val_loss:2.0474 val_bpb:1.2126 train_time:547443ms step_avg:138.59ms +step:4000/20000 train_loss:2.0431 train_time:556202ms step_avg:139.05ms +step:4000/20000 val_loss:2.0443 val_bpb:1.2107 train_time:556290ms step_avg:139.07ms +step:4050/20000 val_loss:2.0412 val_bpb:1.2089 train_time:565198ms step_avg:139.56ms +step:4100/20000 val_loss:2.0383 val_bpb:1.2072 train_time:574028ms step_avg:140.01ms +step:4150/20000 val_loss:2.0359 val_bpb:1.2058 train_time:582879ms step_avg:140.45ms +step:4200/20000 train_loss:2.0466 train_time:591664ms step_avg:140.87ms +step:4200/20000 val_loss:2.0339 val_bpb:1.2046 train_time:591778ms step_avg:140.90ms +step:4247/20000 val_loss:2.0330 val_bpb:1.2041 train_time:600067ms step_avg:141.29ms +stopping_early: wallclock_cap train_time:600067ms step:4247/20000 +peak memory allocated: 30666 MiB reserved: 31708 MiB +swa: averaging 43 checkpoints +Serialized model: 89176650 bytes +Code size: 76171 bytes +Total submission size: 89252821 bytes +Serialized model int8+zstd22: 15403955 bytes (payload:23776800 raw_torch:23793515 payload_ratio:3.75x) +Total submission size int8+zstd22: 15480126 bytes 
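`quantize_state_dict_int8` is defined outside this excerpt (and the record PR ships mixed Int7/Int5 quantization rather than plain int8), but as a minimal dependency-free illustration of the per-tensor symmetric int8 scheme this log reports, the round-trip looks roughly like this (helper names are hypothetical):

```python
def quantize_int8(values: list[float]) -> tuple[list[int], float]:
    """Per-tensor symmetric int8: one fp scale plus integers in [-127, 127]."""
    scale = max(abs(v) for v in values) / 127.0 or 1.0  # avoid zero scale
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize_int8(q: list[int], scale: float) -> list[float]:
    return [x * scale for x in q]

w = [-1.0, -0.5, 0.0, 0.5, 1.0]
q, s = quantize_int8(w)
# round-trip error is bounded by half a quantization step
assert max(abs(a - b) for a, b in zip(dequantize_int8(q, s), w)) <= s / 2 + 1e-12
```

zstd then compresses the int8 payload (level 22 here, overridable via `ZSTD_LEVEL`), which together with the 1-byte codes yields the logged `payload_ratio:3.75x` versus the unquantized tensor bytes.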
+/workspace/parameter-golf/train_gpt_refactored.py:1584: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+  quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu")
+final_roundtrip val_loss:2.0545 val_bpb:1.2168 eval_time:14726ms
+final_roundtrip_exact val_loss:2.05447891 val_bpb:1.21677813
+final_sliding_window val_loss:1.9978 val_bpb:1.1832 window:1024 stride:256 eval_time:78746ms
+final_sliding_window_exact val_loss:1.99781023 val_bpb:1.18321525
+[rank7]: Traceback (most recent call last):
+[rank7]:   File "/workspace/parameter-golf/train_gpt_refactored.py", line 1670, in <module>
+[rank7]:     main()
+[rank7]:   File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main
+[rank7]:     hm_val_loss, hm_val_bpb = eval_val_sliding(
+[rank7]:                               ^^^^^^^^^^^^^^^^^
+[rank7]:   File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding
+[rank7]:     tokens_per_byte = val_token_count.item() / val_byte_count.item()
+[rank7]:                       ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~
+[rank7]: ZeroDivisionError: float division by zero
+final_hedge_mixer val_loss:1.9120 val_bpb:1.1324 eval_time:166534ms
+final_hedge_mixer_exact val_loss:1.91197327 val_bpb:1.13237779
+W0407 18:09:06.564000 26546 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 26612 closing signal SIGTERM
+W0407 18:09:06.569000 26546 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 26614 closing signal SIGTERM
+W0407 18:09:06.573000 26546 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 26615 closing signal SIGTERM
+W0407 18:09:06.576000 26546 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 26616 closing signal SIGTERM
+W0407 18:09:06.579000 26546 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 26617 closing signal SIGTERM
+W0407 18:09:06.586000 26546 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 26618 closing signal SIGTERM
+E0407 18:09:08.071000 26546 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 2 (pid: 26613) of binary: /usr/bin/python
+Traceback (most
recent call last):
+  File "/usr/local/bin/torchrun", line 8, in <module>
+    sys.exit(main())
+             ^^^^^^
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
+    return f(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 919, in main
+    run(args)
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 910, in run
+    elastic_launch(
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 138, in __call__
+    return launch_agent(self._config, self._entrypoint, list(args))
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
+    raise ChildFailedError(
+torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
+============================================================
+train_gpt_refactored.py FAILED
+------------------------------------------------------------
+Failures:
+  <NO_OTHER_FAILURES>
+------------------------------------------------------------
+Root Cause (first observed failure):
+[0]:
+  time      : 2026-04-07_18:09:06
+  host      : 2e165e16ceb6
+  rank      : 2 (local_rank: 2)
+  exitcode  : 1 (pid: 26613)
+  error_file: <N/A>
+  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
+============================================================
diff --git a/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed42.log b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed42.log
new file mode 100644
index 0000000000..a85cf83e91
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed42.log
@@ -0,0 +1,298 @@
+W0407 18:10:10.426000 41945 torch/distributed/run.py:793] 
+W0407 18:10:10.426000 41945 torch/distributed/run.py:793] 
***************************************** +W0407 18:10:10.426000 41945 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0407 18:10:10.426000 41945 torch/distributed/run.py:793] ***************************************** +logs/bcc0f2a5-feb5-4e3f-81b3-77d2f0053cb4.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:23662344 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.021 head_lr:0.0 matrix_lr:0.018 scalar_lr:0.018 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.3, 2), (0.5, 3), (1.0, 4)] starting_repeats=2 +step:0/20000 val_loss:6.9317 val_bpb:4.1053 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9905 train_time:5808ms step_avg:5807.92ms +step:2/20000 train_loss:10.0369 train_time:5828ms step_avg:2913.79ms +step:3/20000 train_loss:9.7214 train_time:5914ms step_avg:1971.24ms +step:4/20000 train_loss:8.9514 train_time:6002ms step_avg:1500.47ms +step:5/20000 train_loss:7.7804 train_time:6091ms step_avg:1218.18ms +step:6/20000 train_loss:6.9234 train_time:6182ms 
step_avg:1030.32ms +step:7/20000 train_loss:6.0257 train_time:6272ms step_avg:896.07ms +step:8/20000 train_loss:5.7572 train_time:6375ms step_avg:796.85ms +step:9/20000 train_loss:5.5914 train_time:6452ms step_avg:716.92ms +step:10/20000 train_loss:5.4264 train_time:6541ms step_avg:654.14ms +step:50/20000 val_loss:4.0061 val_bpb:2.3727 train_time:10179ms step_avg:203.58ms +step:100/20000 val_loss:3.1466 val_bpb:1.8636 train_time:14652ms step_avg:146.52ms +step:150/20000 val_loss:2.8040 val_bpb:1.6607 train_time:19127ms step_avg:127.51ms +step:200/20000 train_loss:2.7127 train_time:23627ms step_avg:118.13ms +step:200/20000 val_loss:2.6859 val_bpb:1.5908 train_time:23685ms step_avg:118.43ms +step:250/20000 val_loss:2.6028 val_bpb:1.5415 train_time:28182ms step_avg:112.73ms +step:300/20000 val_loss:2.5501 val_bpb:1.5103 train_time:32678ms step_avg:108.93ms +step:350/20000 val_loss:2.5103 val_bpb:1.4868 train_time:37179ms step_avg:106.23ms +step:400/20000 train_loss:2.2813 train_time:41703ms step_avg:104.26ms +step:400/20000 val_loss:2.4880 val_bpb:1.4735 train_time:41762ms step_avg:104.40ms +step:450/20000 val_loss:2.4562 val_bpb:1.4547 train_time:46274ms step_avg:102.83ms +step:500/20000 val_loss:2.4407 val_bpb:1.4455 train_time:50788ms step_avg:101.58ms +step:550/20000 val_loss:2.4227 val_bpb:1.4349 train_time:55298ms step_avg:100.54ms +step:600/20000 train_loss:2.5027 train_time:59826ms step_avg:99.71ms +step:600/20000 val_loss:2.4037 val_bpb:1.4236 train_time:59883ms step_avg:99.81ms +step:650/20000 val_loss:2.3894 val_bpb:1.4151 train_time:64394ms step_avg:99.07ms +step:700/20000 val_loss:2.3729 val_bpb:1.4053 train_time:68903ms step_avg:98.43ms +step:750/20000 val_loss:2.3590 val_bpb:1.3972 train_time:73407ms step_avg:97.88ms +step:800/20000 train_loss:2.2626 train_time:77932ms step_avg:97.42ms +step:800/20000 val_loss:2.3515 val_bpb:1.3927 train_time:77990ms step_avg:97.49ms +step:850/20000 val_loss:2.3396 val_bpb:1.3857 train_time:82491ms step_avg:97.05ms 
+step:900/20000 val_loss:2.3324 val_bpb:1.3814 train_time:87000ms step_avg:96.67ms +step:950/20000 val_loss:2.3224 val_bpb:1.3755 train_time:91507ms step_avg:96.32ms +step:1000/20000 train_loss:2.3521 train_time:96029ms step_avg:96.03ms +step:1000/20000 val_loss:2.3126 val_bpb:1.3697 train_time:96087ms step_avg:96.09ms +step:1050/20000 val_loss:2.3054 val_bpb:1.3654 train_time:100597ms step_avg:95.81ms +step:1100/20000 val_loss:2.3003 val_bpb:1.3624 train_time:105104ms step_avg:95.55ms +step:1150/20000 val_loss:2.3010 val_bpb:1.3628 train_time:109671ms step_avg:95.37ms +step:1200/20000 train_loss:2.3687 train_time:114117ms step_avg:95.10ms +step:1200/20000 val_loss:2.2880 val_bpb:1.3551 train_time:114175ms step_avg:95.15ms +step:1250/20000 val_loss:2.2835 val_bpb:1.3524 train_time:118676ms step_avg:94.94ms +step:1300/20000 val_loss:2.2768 val_bpb:1.3484 train_time:123180ms step_avg:94.75ms +step:1350/20000 val_loss:2.2736 val_bpb:1.3466 train_time:127762ms step_avg:94.64ms +step:1400/20000 train_loss:2.4197 train_time:132212ms step_avg:94.44ms +step:1400/20000 val_loss:2.2679 val_bpb:1.3432 train_time:132269ms step_avg:94.48ms +step:1450/20000 val_loss:2.2654 val_bpb:1.3417 train_time:136765ms step_avg:94.32ms +step:1500/20000 val_loss:2.2603 val_bpb:1.3387 train_time:141265ms step_avg:94.18ms +step:1550/20000 val_loss:2.2624 val_bpb:1.3399 train_time:145842ms step_avg:94.09ms +step:1600/20000 train_loss:2.0892 train_time:150283ms step_avg:93.93ms +step:1600/20000 val_loss:2.2566 val_bpb:1.3365 train_time:150341ms step_avg:93.96ms +step:1650/20000 val_loss:2.2515 val_bpb:1.3335 train_time:154847ms step_avg:93.85ms +step:1700/20000 val_loss:2.2473 val_bpb:1.3310 train_time:159346ms step_avg:93.73ms +step:1750/20000 val_loss:2.2456 val_bpb:1.3300 train_time:163924ms step_avg:93.67ms +step:1800/20000 train_loss:2.1935 train_time:168364ms step_avg:93.54ms +step:1800/20000 val_loss:2.2439 val_bpb:1.3290 train_time:168422ms step_avg:93.57ms +step:1850/20000 
val_loss:2.2371 val_bpb:1.3249 train_time:172920ms step_avg:93.47ms +step:1900/20000 val_loss:2.2405 val_bpb:1.3269 train_time:177418ms step_avg:93.38ms +prog_depth: switched to 3 repeats at step:1928 frac:0.30 +step:1950/20000 val_loss:2.2777 val_bpb:1.3490 train_time:192116ms step_avg:98.52ms +step:2000/20000 train_loss:2.2469 train_time:198652ms step_avg:99.33ms +step:2000/20000 val_loss:2.2375 val_bpb:1.3252 train_time:198742ms step_avg:99.37ms +step:2050/20000 val_loss:2.2255 val_bpb:1.3181 train_time:205377ms step_avg:100.18ms +step:2100/20000 val_loss:2.2253 val_bpb:1.3179 train_time:212097ms step_avg:101.00ms +step:2150/20000 val_loss:2.2141 val_bpb:1.3113 train_time:218742ms step_avg:101.74ms +step:2200/20000 train_loss:2.0529 train_time:225293ms step_avg:102.41ms +step:2200/20000 val_loss:2.2100 val_bpb:1.3089 train_time:225383ms step_avg:102.45ms +step:2250/20000 val_loss:2.2093 val_bpb:1.3085 train_time:232034ms step_avg:103.13ms +step:2300/20000 val_loss:2.2023 val_bpb:1.3044 train_time:238745ms step_avg:103.80ms +step:2350/20000 val_loss:2.2026 val_bpb:1.3045 train_time:245395ms step_avg:104.42ms +step:2400/20000 train_loss:2.1791 train_time:251950ms step_avg:104.98ms +step:2400/20000 val_loss:2.1957 val_bpb:1.3004 train_time:252039ms step_avg:105.02ms +step:2450/20000 val_loss:2.1960 val_bpb:1.3006 train_time:258686ms step_avg:105.59ms +step:2500/20000 val_loss:2.1899 val_bpb:1.2970 train_time:265415ms step_avg:106.17ms +step:2550/20000 val_loss:2.1885 val_bpb:1.2961 train_time:272065ms step_avg:106.69ms +step:2600/20000 train_loss:2.3936 train_time:278616ms step_avg:107.16ms +step:2600/20000 val_loss:2.1956 val_bpb:1.3003 train_time:278706ms step_avg:107.19ms +step:2650/20000 val_loss:2.1860 val_bpb:1.2947 train_time:285360ms step_avg:107.68ms +step:2700/20000 val_loss:2.1820 val_bpb:1.2923 train_time:292073ms step_avg:108.18ms +step:2750/20000 val_loss:2.1788 val_bpb:1.2904 train_time:298718ms step_avg:108.62ms +prog_depth: switched to 4 repeats at 
step:2760 frac:0.50 +step:2800/20000 train_loss:2.2110 train_time:318478ms step_avg:113.74ms +step:2800/20000 val_loss:2.1757 val_bpb:1.2886 train_time:318566ms step_avg:113.77ms +step:2850/20000 val_loss:2.1603 val_bpb:1.2794 train_time:327323ms step_avg:114.85ms +step:2900/20000 val_loss:2.1521 val_bpb:1.2746 train_time:336174ms step_avg:115.92ms +step:2950/20000 val_loss:2.1455 val_bpb:1.2707 train_time:344963ms step_avg:116.94ms +step:3000/20000 train_loss:2.1686 train_time:353670ms step_avg:117.89ms +step:3000/20000 val_loss:2.1385 val_bpb:1.2665 train_time:353758ms step_avg:117.92ms +step:3050/20000 val_loss:2.1375 val_bpb:1.2660 train_time:362562ms step_avg:118.87ms +step:3100/20000 val_loss:2.1295 val_bpb:1.2612 train_time:371433ms step_avg:119.82ms +step:3150/20000 val_loss:2.1271 val_bpb:1.2598 train_time:380236ms step_avg:120.71ms +swa:start step:3180 +step:3200/20000 train_loss:2.1184 train_time:389003ms step_avg:121.56ms +step:3200/20000 val_loss:2.1206 val_bpb:1.2560 train_time:389091ms step_avg:121.59ms +step:3250/20000 val_loss:2.1170 val_bpb:1.2538 train_time:398022ms step_avg:122.47ms +step:3300/20000 val_loss:2.1129 val_bpb:1.2514 train_time:406892ms step_avg:123.30ms +step:3350/20000 val_loss:2.1098 val_bpb:1.2495 train_time:415717ms step_avg:124.09ms +step:3400/20000 train_loss:2.0800 train_time:424478ms step_avg:124.85ms +step:3400/20000 val_loss:2.1043 val_bpb:1.2463 train_time:424566ms step_avg:124.87ms +step:3450/20000 val_loss:2.1009 val_bpb:1.2443 train_time:433496ms step_avg:125.65ms +step:3500/20000 val_loss:2.0947 val_bpb:1.2406 train_time:442340ms step_avg:126.38ms +step:3550/20000 val_loss:2.0906 val_bpb:1.2382 train_time:451208ms step_avg:127.10ms +step:3600/20000 train_loss:2.0179 train_time:459967ms step_avg:127.77ms +step:3600/20000 val_loss:2.0857 val_bpb:1.2353 train_time:460082ms step_avg:127.80ms +step:3650/20000 val_loss:2.0820 val_bpb:1.2331 train_time:468992ms step_avg:128.49ms +step:3700/20000 val_loss:2.0790 
val_bpb:1.2313 train_time:477870ms step_avg:129.15ms +step:3750/20000 val_loss:2.0751 val_bpb:1.2290 train_time:486749ms step_avg:129.80ms +step:3800/20000 train_loss:2.1125 train_time:495491ms step_avg:130.39ms +step:3800/20000 val_loss:2.0705 val_bpb:1.2263 train_time:495579ms step_avg:130.42ms +step:3850/20000 val_loss:2.0667 val_bpb:1.2240 train_time:504495ms step_avg:131.04ms +step:3900/20000 val_loss:2.0635 val_bpb:1.2221 train_time:513355ms step_avg:131.63ms +step:3950/20000 val_loss:2.0584 val_bpb:1.2191 train_time:522184ms step_avg:132.20ms +step:4000/20000 train_loss:2.0573 train_time:530943ms step_avg:132.74ms +step:4000/20000 val_loss:2.0549 val_bpb:1.2170 train_time:531031ms step_avg:132.76ms +step:4050/20000 val_loss:2.0515 val_bpb:1.2150 train_time:539973ms step_avg:133.33ms +step:4100/20000 val_loss:2.0477 val_bpb:1.2128 train_time:548795ms step_avg:133.85ms +step:4150/20000 val_loss:2.0447 val_bpb:1.2110 train_time:557674ms step_avg:134.38ms +step:4200/20000 train_loss:2.0542 train_time:566481ms step_avg:134.88ms +step:4200/20000 val_loss:2.0416 val_bpb:1.2092 train_time:566595ms step_avg:134.90ms +step:4250/20000 val_loss:2.0390 val_bpb:1.2076 train_time:575436ms step_avg:135.40ms +step:4300/20000 val_loss:2.0365 val_bpb:1.2062 train_time:584295ms step_avg:135.88ms +step:4350/20000 val_loss:2.0347 val_bpb:1.2051 train_time:593178ms step_avg:136.36ms +step:4389/20000 val_loss:2.0339 val_bpb:1.2046 train_time:600131ms step_avg:136.74ms +stopping_early: wallclock_cap train_time:600131ms step:4389/20000 +peak memory allocated: 30342 MiB reserved: 31078 MiB +swa: averaging 42 checkpoints +Serialized model: 89176650 bytes +Code size: 76171 bytes +Total submission size: 89252821 bytes +Serialized model int8+zstd22: 15282718 bytes (payload:23776800 raw_torch:23793515 payload_ratio:3.75x) +Total submission size int8+zstd22: 15358889 bytes +/workspace/parameter-golf/train_gpt_refactored.py:1584: FutureWarning: You are using `torch.load` with 
`weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+  quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu")
+[identical FutureWarning and quant_state line repeated 7 more times, once per rank]
+final_roundtrip val_loss:2.0553 val_bpb:1.2172 eval_time:15396ms
+final_roundtrip_exact val_loss:2.05526682 val_bpb:1.21724478
+final_sliding_window val_loss:1.9991 val_bpb:1.1840 window:1024 stride:256 eval_time:78746ms
+final_sliding_window_exact val_loss:1.99914054 val_bpb:1.18400313
+[rank4]: Traceback (most recent call last):
+[rank4]:   File "/workspace/parameter-golf/train_gpt_refactored.py", line 1670, in <module>
+[rank4]:     main()
+[rank4]:   File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main
+[rank4]:     hm_val_loss, hm_val_bpb = eval_val_sliding(
+[rank4]:                               ^^^^^^^^^^^^^^^^^
+[rank4]:   File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding
+[rank4]:     tokens_per_byte = val_token_count.item() / val_byte_count.item()
+[rank4]:                       ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~
+[rank4]: ZeroDivisionError: float division by zero
+[identical ZeroDivisionError traceback repeated by ranks 3, 7, 6, 1, 2, 5]
+final_hedge_mixer val_loss:1.9340 val_bpb:1.1454 eval_time:163554ms
+final_hedge_mixer_exact val_loss:1.93403856 val_bpb:1.14544609
+W0407 18:32:03.362000 41945 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 42011 closing signal SIGTERM
+W0407 18:32:03.363000 41945 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 42012 closing signal SIGTERM
+W0407 18:32:03.364000 41945 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 42013 closing signal SIGTERM
+W0407 18:32:03.365000 41945 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 42015 closing signal SIGTERM
+W0407 18:32:03.366000 41945 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 42016 closing signal SIGTERM
+W0407 18:32:03.367000 41945 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 42017 closing signal SIGTERM
+E0407 18:32:04.323000 41945 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 4 (pid: 42014) of binary: /usr/bin/python
+Traceback (most
recent call last):
+  File "/usr/local/bin/torchrun", line 8, in <module>
+    sys.exit(main())
+             ^^^^^^
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
+    return f(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 919, in main
+    run(args)
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 910, in run
+    elastic_launch(
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 138, in __call__
+    return launch_agent(self._config, self._entrypoint, list(args))
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
+    raise ChildFailedError(
+torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
+============================================================
+train_gpt_refactored.py FAILED
+------------------------------------------------------------
+Failures:
+  <NO_OTHER_FAILURES>
+------------------------------------------------------------
+Root Cause (first observed failure):
+[0]:
+  time      : 2026-04-07_18:32:03
+  host      : 2e165e16ceb6
+  rank      : 4 (local_rank: 4)
+  exitcode  : 1 (pid: 42014)
+  error_file: <N/A>
+  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
+============================================================
diff --git a/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed7.log b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed7.log
new file mode 100644
index 0000000000..edc41f1427
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-08_DepthRecurrence_Int7MixedQuant_HedgeMixer/train_seed7.log
@@ -0,0 +1,299 @@
+W0407 18:33:12.749000 45793 torch/distributed/run.py:793] 
+W0407 18:33:12.749000 45793 torch/distributed/run.py:793] 
*****************************************
+W0407 18:33:12.749000 45793 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0407 18:33:12.749000 45793 torch/distributed/run.py:793] *****************************************
+Traceback (most recent call last):
+  File "/usr/local/bin/torchrun", line 8, in <module>
+    sys.exit(main())
+             ^^^^^^
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
+    return f(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 919, in main
+    run(args)
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 910, in run
+    elastic_launch(
+  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 138, in __call__
+    return
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+prog_depth: schedule=[(0.3, 2), (0.5, 3), (1.0, 4)] starting_repeats=2
+step:0/20000 val_loss:6.9322 val_bpb:4.1056 train_time:0ms step_avg:0.01ms
+step:1/20000 train_loss:6.9932 train_time:5848ms step_avg:5847.68ms
+step:2/20000 train_loss:9.9975 train_time:5867ms step_avg:2933.44ms
+step:3/20000 train_loss:9.6852 train_time:5954ms step_avg:1984.61ms
+step:4/20000 train_loss:8.9228 train_time:6042ms step_avg:1510.43ms
+step:5/20000 train_loss:7.7753 train_time:6139ms step_avg:1227.82ms
+step:6/20000 train_loss:6.9328 train_time:6224ms step_avg:1037.37ms
+step:7/20000 train_loss:6.0226 
train_time:6315ms step_avg:902.09ms +step:8/20000 train_loss:5.7609 train_time:6406ms step_avg:800.69ms +step:9/20000 train_loss:5.6007 train_time:6496ms step_avg:721.77ms +step:10/20000 train_loss:5.4417 train_time:6587ms step_avg:658.73ms +step:50/20000 val_loss:4.0227 val_bpb:2.3825 train_time:10231ms step_avg:204.62ms +step:100/20000 val_loss:3.1260 val_bpb:1.8514 train_time:14709ms step_avg:147.09ms +step:150/20000 val_loss:2.8145 val_bpb:1.6669 train_time:19189ms step_avg:127.93ms +step:200/20000 train_loss:2.7073 train_time:23676ms step_avg:118.38ms +step:200/20000 val_loss:2.6976 val_bpb:1.5977 train_time:23734ms step_avg:118.67ms +step:250/20000 val_loss:2.6070 val_bpb:1.5440 train_time:28231ms step_avg:112.92ms +step:300/20000 val_loss:2.5557 val_bpb:1.5136 train_time:32731ms step_avg:109.10ms +step:350/20000 val_loss:2.5119 val_bpb:1.4877 train_time:37231ms step_avg:106.37ms +step:400/20000 train_loss:2.2702 train_time:41757ms step_avg:104.39ms +step:400/20000 val_loss:2.4847 val_bpb:1.4716 train_time:41815ms step_avg:104.54ms +step:450/20000 val_loss:2.4538 val_bpb:1.4533 train_time:46318ms step_avg:102.93ms +step:500/20000 val_loss:2.4394 val_bpb:1.4448 train_time:50824ms step_avg:101.65ms +step:550/20000 val_loss:2.4238 val_bpb:1.4355 train_time:55332ms step_avg:100.60ms +step:600/20000 train_loss:2.5007 train_time:59848ms step_avg:99.75ms +step:600/20000 val_loss:2.4044 val_bpb:1.4240 train_time:59904ms step_avg:99.84ms +step:650/20000 val_loss:2.3866 val_bpb:1.4135 train_time:64419ms step_avg:99.11ms +step:700/20000 val_loss:2.3685 val_bpb:1.4028 train_time:68929ms step_avg:98.47ms +step:750/20000 val_loss:2.3552 val_bpb:1.3949 train_time:73437ms step_avg:97.92ms +step:800/20000 train_loss:2.2588 train_time:77960ms step_avg:97.45ms +step:800/20000 val_loss:2.3469 val_bpb:1.3900 train_time:78019ms step_avg:97.52ms +step:850/20000 val_loss:2.3349 val_bpb:1.3829 train_time:82521ms step_avg:97.08ms +step:900/20000 val_loss:2.3276 val_bpb:1.3785 
train_time:87028ms step_avg:96.70ms +step:950/20000 val_loss:2.3171 val_bpb:1.3723 train_time:91533ms step_avg:96.35ms +step:1000/20000 train_loss:2.3494 train_time:96048ms step_avg:96.05ms +step:1000/20000 val_loss:2.3076 val_bpb:1.3667 train_time:96105ms step_avg:96.11ms +step:1050/20000 val_loss:2.3004 val_bpb:1.3624 train_time:100608ms step_avg:95.82ms +step:1100/20000 val_loss:2.2960 val_bpb:1.3598 train_time:105109ms step_avg:95.55ms +step:1150/20000 val_loss:2.2949 val_bpb:1.3591 train_time:109671ms step_avg:95.37ms +step:1200/20000 train_loss:2.3626 train_time:114119ms step_avg:95.10ms +step:1200/20000 val_loss:2.2810 val_bpb:1.3509 train_time:114177ms step_avg:95.15ms +step:1250/20000 val_loss:2.2791 val_bpb:1.3498 train_time:118683ms step_avg:94.95ms +step:1300/20000 val_loss:2.2693 val_bpb:1.3440 train_time:123182ms step_avg:94.76ms +step:1350/20000 val_loss:2.2676 val_bpb:1.3430 train_time:127741ms step_avg:94.62ms +step:1400/20000 train_loss:2.4113 train_time:132188ms step_avg:94.42ms +step:1400/20000 val_loss:2.2617 val_bpb:1.3395 train_time:132246ms step_avg:94.46ms +step:1450/20000 val_loss:2.2604 val_bpb:1.3387 train_time:136747ms step_avg:94.31ms +step:1500/20000 val_loss:2.2549 val_bpb:1.3355 train_time:141247ms step_avg:94.16ms +step:1550/20000 val_loss:2.2582 val_bpb:1.3374 train_time:145816ms step_avg:94.07ms +step:1600/20000 train_loss:2.0807 train_time:150260ms step_avg:93.91ms +step:1600/20000 val_loss:2.2508 val_bpb:1.3331 train_time:150319ms step_avg:93.95ms +step:1650/20000 val_loss:2.2464 val_bpb:1.3305 train_time:154815ms step_avg:93.83ms +step:1700/20000 val_loss:2.2417 val_bpb:1.3276 train_time:159309ms step_avg:93.71ms +step:1750/20000 val_loss:2.2397 val_bpb:1.3265 train_time:163869ms step_avg:93.64ms +step:1800/20000 train_loss:2.1850 train_time:168309ms step_avg:93.51ms +step:1800/20000 val_loss:2.2387 val_bpb:1.3259 train_time:168368ms step_avg:93.54ms +step:1850/20000 val_loss:2.2328 val_bpb:1.3224 train_time:172867ms 
step_avg:93.44ms +step:1900/20000 val_loss:2.2355 val_bpb:1.3240 train_time:177367ms step_avg:93.35ms +prog_depth: switched to 3 repeats at step:1929 frac:0.30 +step:1950/20000 val_loss:2.2771 val_bpb:1.3486 train_time:191990ms step_avg:98.46ms +step:2000/20000 train_loss:2.2512 train_time:198522ms step_avg:99.26ms +step:2000/20000 val_loss:2.2340 val_bpb:1.3231 train_time:198612ms step_avg:99.31ms +step:2050/20000 val_loss:2.2220 val_bpb:1.3160 train_time:205246ms step_avg:100.12ms +step:2100/20000 val_loss:2.2171 val_bpb:1.3131 train_time:211958ms step_avg:100.93ms +step:2150/20000 val_loss:2.2095 val_bpb:1.3086 train_time:218600ms step_avg:101.67ms +step:2200/20000 train_loss:2.0481 train_time:225153ms step_avg:102.34ms +step:2200/20000 val_loss:2.2044 val_bpb:1.3056 train_time:225241ms step_avg:102.38ms +step:2250/20000 val_loss:2.2037 val_bpb:1.3051 train_time:231889ms step_avg:103.06ms +step:2300/20000 val_loss:2.1969 val_bpb:1.3012 train_time:238596ms step_avg:103.74ms +step:2350/20000 val_loss:2.1980 val_bpb:1.3018 train_time:245247ms step_avg:104.36ms +step:2400/20000 train_loss:2.1752 train_time:251809ms step_avg:104.92ms +step:2400/20000 val_loss:2.1926 val_bpb:1.2986 train_time:251900ms step_avg:104.96ms +step:2450/20000 val_loss:2.1913 val_bpb:1.2978 train_time:258549ms step_avg:105.53ms +step:2500/20000 val_loss:2.1854 val_bpb:1.2943 train_time:265268ms step_avg:106.11ms +step:2550/20000 val_loss:2.1855 val_bpb:1.2944 train_time:271914ms step_avg:106.63ms +step:2600/20000 train_loss:2.3890 train_time:278470ms step_avg:107.10ms +step:2600/20000 val_loss:2.1905 val_bpb:1.2974 train_time:278561ms step_avg:107.14ms +step:2650/20000 val_loss:2.1815 val_bpb:1.2920 train_time:285206ms step_avg:107.63ms +step:2700/20000 val_loss:2.1774 val_bpb:1.2896 train_time:291918ms step_avg:108.12ms +step:2750/20000 val_loss:2.1757 val_bpb:1.2886 train_time:298565ms step_avg:108.57ms +prog_depth: switched to 4 repeats at step:2761 frac:0.50 +step:2800/20000 
train_loss:2.2098 train_time:318158ms step_avg:113.63ms +step:2800/20000 val_loss:2.1724 val_bpb:1.2866 train_time:318246ms step_avg:113.66ms +step:2850/20000 val_loss:2.1574 val_bpb:1.2777 train_time:327002ms step_avg:114.74ms +step:2900/20000 val_loss:2.1488 val_bpb:1.2726 train_time:335847ms step_avg:115.81ms +step:2950/20000 val_loss:2.1426 val_bpb:1.2689 train_time:344637ms step_avg:116.83ms +step:3000/20000 train_loss:2.1666 train_time:353337ms step_avg:117.78ms +step:3000/20000 val_loss:2.1369 val_bpb:1.2656 train_time:353425ms step_avg:117.81ms +step:3050/20000 val_loss:2.1353 val_bpb:1.2647 train_time:362230ms step_avg:118.76ms +step:3100/20000 val_loss:2.1275 val_bpb:1.2600 train_time:371097ms step_avg:119.71ms +step:3150/20000 val_loss:2.1248 val_bpb:1.2584 train_time:379901ms step_avg:120.60ms +swa:start step:3180 +step:3200/20000 train_loss:2.1195 train_time:388667ms step_avg:121.46ms +step:3200/20000 val_loss:2.1180 val_bpb:1.2544 train_time:388755ms step_avg:121.49ms +step:3250/20000 val_loss:2.1147 val_bpb:1.2525 train_time:397698ms step_avg:122.37ms +step:3300/20000 val_loss:2.1107 val_bpb:1.2501 train_time:406572ms step_avg:123.20ms +step:3350/20000 val_loss:2.1075 val_bpb:1.2482 train_time:415399ms step_avg:124.00ms +step:3400/20000 train_loss:2.0720 train_time:424194ms step_avg:124.76ms +step:3400/20000 val_loss:2.1023 val_bpb:1.2451 train_time:424281ms step_avg:124.79ms +step:3450/20000 val_loss:2.0992 val_bpb:1.2432 train_time:433216ms step_avg:125.57ms +step:3500/20000 val_loss:2.0926 val_bpb:1.2394 train_time:442059ms step_avg:126.30ms +step:3550/20000 val_loss:2.0886 val_bpb:1.2370 train_time:450919ms step_avg:127.02ms +step:3600/20000 train_loss:2.0163 train_time:459663ms step_avg:127.68ms +step:3600/20000 val_loss:2.0839 val_bpb:1.2342 train_time:459777ms step_avg:127.72ms +step:3650/20000 val_loss:2.0803 val_bpb:1.2320 train_time:468670ms step_avg:128.40ms +step:3700/20000 val_loss:2.0769 val_bpb:1.2301 train_time:477528ms 
step_avg:129.06ms +step:3750/20000 val_loss:2.0732 val_bpb:1.2279 train_time:486381ms step_avg:129.70ms +step:3800/20000 train_loss:2.1148 train_time:495143ms step_avg:130.30ms +step:3800/20000 val_loss:2.0685 val_bpb:1.2251 train_time:495231ms step_avg:130.32ms +step:3850/20000 val_loss:2.0649 val_bpb:1.2230 train_time:504163ms step_avg:130.95ms +step:3900/20000 val_loss:2.0616 val_bpb:1.2210 train_time:513027ms step_avg:131.55ms +step:3950/20000 val_loss:2.0567 val_bpb:1.2181 train_time:521849ms step_avg:132.11ms +step:4000/20000 train_loss:2.0534 train_time:530607ms step_avg:132.65ms +step:4000/20000 val_loss:2.0533 val_bpb:1.2161 train_time:530695ms step_avg:132.67ms +step:4050/20000 val_loss:2.0497 val_bpb:1.2140 train_time:539631ms step_avg:133.24ms +step:4100/20000 val_loss:2.0462 val_bpb:1.2119 train_time:548469ms step_avg:133.77ms +step:4150/20000 val_loss:2.0431 val_bpb:1.2101 train_time:557332ms step_avg:134.30ms +step:4200/20000 train_loss:2.0542 train_time:566146ms step_avg:134.80ms +step:4200/20000 val_loss:2.0400 val_bpb:1.2082 train_time:566263ms step_avg:134.82ms +step:4250/20000 val_loss:2.0374 val_bpb:1.2066 train_time:575089ms step_avg:135.32ms +step:4300/20000 val_loss:2.0349 val_bpb:1.2052 train_time:583962ms step_avg:135.81ms +step:4350/20000 val_loss:2.0330 val_bpb:1.2041 train_time:592832ms step_avg:136.28ms +step:4391/20000 val_loss:2.0322 val_bpb:1.2036 train_time:600153ms step_avg:136.68ms +stopping_early: wallclock_cap train_time:600153ms step:4391/20000 +peak memory allocated: 30342 MiB reserved: 31078 MiB +swa: averaging 42 checkpoints +Serialized model: 89176650 bytes +Code size: 76171 bytes +Total submission size: 89252821 bytes +Serialized model int8+zstd22: 15285123 bytes (payload:23776800 raw_torch:23793515 payload_ratio:3.75x) +Total submission size int8+zstd22: 15361294 bytes +/workspace/parameter-golf/train_gpt_refactored.py:1584: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), 
which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +final_roundtrip val_loss:2.0536 val_bpb:1.2163 eval_time:15264ms +final_roundtrip_exact val_loss:2.05362252 val_bpb:1.21627093 +final_sliding_window val_loss:1.9972 val_bpb:1.1828 window:1024 stride:256 eval_time:78463ms +final_sliding_window_exact val_loss:1.99719348 val_bpb:1.18284997 +[rank5]: Traceback (most recent call last): +[rank5]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1670, in +[rank5]: main() +[rank5]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main +[rank5]: hm_val_loss, hm_val_bpb = eval_val_sliding( +[rank5]: ^^^^^^^^^^^^^^^^^ +[rank5]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding +[rank5]: tokens_per_byte = val_token_count.item() / val_byte_count.item() +[rank5]: ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~ +[rank5]: ZeroDivisionError: float division by zero +final_hedge_mixer val_loss:1.8899 val_bpb:1.1193 eval_time:162912ms +final_hedge_mixer_exact val_loss:1.88989896 val_bpb:1.11930414 +[rank1]: Traceback (most recent call last): +[rank1]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1670, in +[rank1]: main() +[rank1]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main +[rank1]: hm_val_loss, hm_val_bpb = eval_val_sliding( +[rank1]: ^^^^^^^^^^^^^^^^^ +[rank1]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding +[rank1]: tokens_per_byte = val_token_count.item() / val_byte_count.item() +[rank1]: 
~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~ +[rank1]: ZeroDivisionError: float division by zero +[rank3]: Traceback (most recent call last): +[rank3]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1670, in +[rank3]: main() +[rank3]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main +[rank3]: hm_val_loss, hm_val_bpb = eval_val_sliding( +[rank3]: ^^^^^^^^^^^^^^^^^ +[rank3]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding +[rank3]: tokens_per_byte = val_token_count.item() / val_byte_count.item() +[rank3]: ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~ +[rank3]: ZeroDivisionError: float division by zero +[rank6]: Traceback (most recent call last): +[rank6]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1670, in +[rank6]: main() +[rank6]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main +[rank6]: hm_val_loss, hm_val_bpb = eval_val_sliding( +[rank6]: ^^^^^^^^^^^^^^^^^ +[rank6]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding +[rank6]: tokens_per_byte = val_token_count.item() / val_byte_count.item() +[rank6]: ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~ +[rank6]: ZeroDivisionError: float division by zero +[rank7]: Traceback (most recent call last): +[rank7]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1670, in +[rank7]: main() +[rank7]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main +[rank7]: hm_val_loss, hm_val_bpb = eval_val_sliding( +[rank7]: ^^^^^^^^^^^^^^^^^ +[rank7]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding +[rank7]: tokens_per_byte = val_token_count.item() / val_byte_count.item() +[rank7]: ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~ +[rank7]: ZeroDivisionError: float division by zero +[rank4]: Traceback (most recent call last): +[rank4]: File "/workspace/parameter-golf/train_gpt_refactored.py", 
line 1670, in +[rank4]: main() +[rank4]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main +[rank4]: hm_val_loss, hm_val_bpb = eval_val_sliding( +[rank4]: ^^^^^^^^^^^^^^^^^ +[rank4]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding +[rank4]: tokens_per_byte = val_token_count.item() / val_byte_count.item() +[rank4]: ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~ +[rank4]: ZeroDivisionError: float division by zero +[rank2]: Traceback (most recent call last): +[rank2]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1670, in +[rank2]: main() +[rank2]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 1636, in main +[rank2]: hm_val_loss, hm_val_bpb = eval_val_sliding( +[rank2]: ^^^^^^^^^^^^^^^^^ +[rank2]: File "/workspace/parameter-golf/train_gpt_refactored.py", line 476, in eval_val_sliding +[rank2]: tokens_per_byte = val_token_count.item() / val_byte_count.item() +[rank2]: ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~ +[rank2]: ZeroDivisionError: float division by zero +W0407 18:54:26.949000 44428 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 44494 closing signal SIGTERM +W0407 18:54:26.951000 44428 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 44496 closing signal SIGTERM +W0407 18:54:26.952000 44428 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 44497 closing signal SIGTERM +W0407 18:54:26.952000 44428 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 44498 closing signal SIGTERM +W0407 18:54:26.953000 44428 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 44499 closing signal SIGTERM +W0407 18:54:26.954000 44428 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 44500 closing signal SIGTERM +E0407 18:54:27.741000 44428 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 2 (pid: 44495) of binary: 
/usr/bin/python +Traceback (most recent call last): + File "/usr/local/bin/torchrun", line 8, in + sys.exit(main()) + ^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 919, in main + run(args) + File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 910, in run + elastic_launch( + File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 138, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 269, in launch_agent + raise ChildFailedError( +torch.distributed.elastic.multiprocessing.errors.ChildFailedError: +============================================================ +train_gpt_refactored.py FAILED +------------------------------------------------------------ +Failures: + +------------------------------------------------------------ +Root Cause (first observed failure): +[0]: + time : 2026-04-07_18:54:26 + host : 2e165e16ceb6 + rank : 2 (local_rank: 2) + exitcode : 1 (pid: 44495) + error_file: + traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html +============================================================